Commit 6fa895d
Parent(s): 3f2f51a

Upload 77 files

This view is limited to 50 files because it contains too many changes. See raw diff for the full set.
- sd/stable-diffusion-webui/modules/call_queue.py +109 -0
- sd/stable-diffusion-webui/modules/codeformer_model.py +143 -0
- sd/stable-diffusion-webui/modules/deepbooru.py +99 -0
- sd/stable-diffusion-webui/modules/deepbooru_model.py +678 -0
- sd/stable-diffusion-webui/modules/devices.py +152 -0
- sd/stable-diffusion-webui/modules/errors.py +43 -0
- sd/stable-diffusion-webui/modules/esrgan_model.py +233 -0
- sd/stable-diffusion-webui/modules/esrgan_model_arch.py +464 -0
- sd/stable-diffusion-webui/modules/extensions.py +107 -0
- sd/stable-diffusion-webui/modules/extra_networks.py +147 -0
- sd/stable-diffusion-webui/modules/extra_networks_hypernet.py +27 -0
- sd/stable-diffusion-webui/modules/extras.py +258 -0
- sd/stable-diffusion-webui/modules/face_restoration.py +19 -0
- sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py +402 -0
- sd/stable-diffusion-webui/modules/gfpgan_model.py +116 -0
- sd/stable-diffusion-webui/modules/hashes.py +91 -0
- sd/stable-diffusion-webui/modules/images.py +669 -0
- sd/stable-diffusion-webui/modules/img2img.py +184 -0
- sd/stable-diffusion-webui/modules/import_hook.py +5 -0
- sd/stable-diffusion-webui/modules/interrogate.py +227 -0
- sd/stable-diffusion-webui/modules/localization.py +37 -0
- sd/stable-diffusion-webui/modules/lowvram.py +96 -0
- sd/stable-diffusion-webui/modules/mac_specific.py +53 -0
- sd/stable-diffusion-webui/modules/masking.py +99 -0
- sd/stable-diffusion-webui/modules/memmon.py +88 -0
- sd/stable-diffusion-webui/modules/modelloader.py +172 -0
- sd/stable-diffusion-webui/modules/ngrok.py +26 -0
- sd/stable-diffusion-webui/modules/paths.py +62 -0
- sd/stable-diffusion-webui/modules/postprocessing.py +103 -0
- sd/stable-diffusion-webui/modules/processing.py +1056 -0
- sd/stable-diffusion-webui/modules/progress.py +99 -0
- sd/stable-diffusion-webui/modules/prompt_parser.py +373 -0
- sd/stable-diffusion-webui/modules/realesrgan_model.py +129 -0
- sd/stable-diffusion-webui/modules/safe.py +192 -0
- sd/stable-diffusion-webui/modules/script_callbacks.py +359 -0
- sd/stable-diffusion-webui/modules/script_loading.py +32 -0
- sd/stable-diffusion-webui/modules/scripts.py +501 -0
- sd/stable-diffusion-webui/modules/scripts_auto_postprocessing.py +42 -0
- sd/stable-diffusion-webui/modules/scripts_postprocessing.py +152 -0
- sd/stable-diffusion-webui/modules/sd_disable_initialization.py +93 -0
- sd/stable-diffusion-webui/modules/sd_hijack.py +264 -0
- sd/stable-diffusion-webui/modules/sd_hijack_checkpoint.py +46 -0
- sd/stable-diffusion-webui/modules/sd_hijack_clip.py +317 -0
- sd/stable-diffusion-webui/modules/sd_hijack_clip_old.py +81 -0
- sd/stable-diffusion-webui/modules/sd_hijack_inpainting.py +103 -0
- sd/stable-diffusion-webui/modules/sd_hijack_ip2p.py +13 -0
- sd/stable-diffusion-webui/modules/sd_hijack_open_clip.py +37 -0
- sd/stable-diffusion-webui/modules/sd_hijack_optimizations.py +444 -0
- sd/stable-diffusion-webui/modules/sd_hijack_unet.py +79 -0
- sd/stable-diffusion-webui/modules/sd_hijack_utils.py +28 -0
sd/stable-diffusion-webui/modules/call_queue.py
ADDED
@@ -0,0 +1,109 @@
import html
import sys
import threading
import traceback
import time

from modules import shared, progress

queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return f


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        # if the first argument is a string that says "task(...)", it is treated as a job id
        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
            id_task = args[0]
            progress.add_task_to_queue(id_task)
        else:
            id_task = None

        with queue_lock:
            shared.state.begin()
            progress.start_task(id_task)

            try:
                res = func(*args, **kwargs)
            finally:
                progress.finish_task(id_task)

            shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
            max_debug_str_len = 131072  # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {str(args)} {str(kwargs)}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)

            print(traceback.format_exc(), file=sys.stderr)

            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f
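
The wrappers above serialize every Gradio-triggered GPU job behind a single lock. Below is a minimal standalone sketch of that same pattern, assuming nothing from the webui's shared/progress modules; the generate function is a hypothetical stand-in for a real GPU task.

import threading
import time

queue_lock = threading.Lock()

def wrap_queued_call(func):
    # serialize wrapped calls behind one lock so only one job runs at a time,
    # mirroring what call_queue.py does for GPU work
    def f(*args, **kwargs):
        with queue_lock:
            return func(*args, **kwargs)
    return f

def generate(prompt):  # hypothetical stand-in for a GPU task
    time.sleep(0.1)
    return f"image for {prompt!r}"

queued_generate = wrap_queued_call(generate)

threads = [threading.Thread(target=queued_generate, args=(f"prompt {i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
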
sd/stable-diffusion-webui/modules/codeformer_model.py
ADDED
@@ -0,0 +1,143 @@
import os
import sys
import traceback

import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
from modules.paths import models_path

# codeformer people made a choice to include a modified basicsr library in their project, which makes
# it utterly impossible to use alongside other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

have_codeformer = False
codeformer = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return

    try:
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer
        from basicsr.utils.download_util import load_file_from_url
        from basicsr.utils import imwrite, img2tensor, tensor2img
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
        from facelib.detection.retinaface import retinaface
        from modules.shared import cmd_opts

        net_class = CodeFormer

        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
            def name(self):
                return "CodeFormer"

            def __init__(self, dirname):
                self.net = None
                self.face_helper = None
                self.cmd_dir = dirname

            def create_models(self):

                if self.net is not None and self.face_helper is not None:
                    self.net.to(devices.device_codeformer)
                    return self.net, self.face_helper
                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
                if len(model_paths) != 0:
                    ckpt_path = model_paths[0]
                else:
                    print("Unable to load codeformer model.")
                    return None, None
                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                checkpoint = torch.load(ckpt_path)['params_ema']
                net.load_state_dict(checkpoint)
                net.eval()

                if hasattr(retinaface, 'device'):
                    retinaface.device = devices.device_codeformer
                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)

                self.net = net
                self.face_helper = face_helper

                return net, face_helper

            def send_model_to(self, device):
                self.net.to(device)
                self.face_helper.face_det.to(device)
                self.face_helper.face_parse.to(device)

            def restore(self, np_image, w=None):
                np_image = np_image[:, :, ::-1]

                original_resolution = np_image.shape[0:2]

                self.create_models()
                if self.net is None or self.face_helper is None:
                    return np_image

                self.send_model_to(devices.device_codeformer)

                self.face_helper.clean_all()
                self.face_helper.read_image(np_image)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()

                for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

                    try:
                        with torch.no_grad():
                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                        del output
                        torch.cuda.empty_cache()
                    except Exception as error:
                        print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    restored_face = restored_face.astype('uint8')
                    self.face_helper.add_restored_face(restored_face)

                self.face_helper.get_inverse_affine(None)

                restored_img = self.face_helper.paste_faces_to_input_image()
                restored_img = restored_img[:, :, ::-1]

                if original_resolution != restored_img.shape[0:2]:
                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)

                self.face_helper.clean_all()

                if shared.opts.face_restoration_unload:
                    self.send_model_to(devices.cpu)

                return restored_img

        global have_codeformer
        have_codeformer = True

        global codeformer
        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)

    except Exception:
        print("Error setting up CodeFormer:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    # sys.path = stored_sys_path
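
A hedged usage sketch of the restorer above, assuming an initialized webui process in which setup_model() succeeds; the model directory and image path are illustrative, not from this commit. restore() takes an RGB uint8 numpy array, and the weight w (0..1) trades identity fidelity against restoration strength.

import cv2
from modules import codeformer_model

codeformer_model.setup_model("models/Codeformer")  # illustrative dirname
if codeformer_model.have_codeformer:
    image = cv2.imread("face.png")[:, :, ::-1]  # BGR -> RGB; illustrative path
    restored = codeformer_model.codeformer.restore(image, w=0.5)
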
sd/stable-diffusion-webui/modules/deepbooru.py
ADDED
@@ -0,0 +1,99 @@
import os
import re

import torch
from PIL import Image
import numpy as np

from modules import modelloader, paths, deepbooru_model, devices, images, shared

re_special = re.compile(r'([\\()])')


class DeepDanbooru:
    def __init__(self):
        self.model = None

    def load(self):
        if self.model is not None:
            return

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
            ext_filter=[".pt"],
            download_name='model-resnet_custom_v3.pt',
        )

        self.model = deepbooru_model.DeepDanbooruModel()
        self.model.load_state_dict(torch.load(files[0], map_location="cpu"))

        self.model.eval()
        self.model.to(devices.cpu, devices.dtype)

    def start(self):
        self.load()
        self.model.to(devices.device)

    def stop(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model.to(devices.cpu)
            devices.torch_gc()

    def tag(self, pil_image):
        self.start()
        res = self.tag_multi(pil_image)
        self.stop()

        return res

    def tag_multi(self, pil_image, force_disable_ranks=False):
        threshold = shared.opts.interrogate_deepbooru_score_threshold
        use_spaces = shared.opts.deepbooru_use_spaces
        use_escape = shared.opts.deepbooru_escape
        alpha_sort = shared.opts.deepbooru_sort_alpha
        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks

        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(a).to(devices.device)
            y = self.model(x)[0].detach().cpu().numpy()

        probability_dict = {}

        for tag, probability in zip(self.model.tags, y):
            if probability < threshold:
                continue

            if tag.startswith("rating:"):
                continue

            probability_dict[tag] = probability

        if alpha_sort:
            tags = sorted(probability_dict)
        else:
            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]

        res = []

        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])

        for tag in [x for x in tags if x not in filtertags]:
            probability = probability_dict[tag]
            tag_outformat = tag
            if use_spaces:
                tag_outformat = tag_outformat.replace('_', ' ')
            if use_escape:
                tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
            if include_ranks:
                tag_outformat = f"({tag_outformat}:{probability:.3f})"

            res.append(tag_outformat)

        return ", ".join(res)


model = DeepDanbooru()
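
A hedged usage sketch of the module-level tagger above, assuming the webui environment is loaded: the first call downloads model-resnet_custom_v3.pt if it is missing, tags the image, then moves the model back to CPU unless interrogate_keep_models_in_memory is set. The image path is illustrative.

from PIL import Image
from modules import deepbooru

img = Image.open("example.png")  # illustrative path
tags = deepbooru.model.tag(img)  # comma-separated tags above the score threshold
print(tags)
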
sd/stable-diffusion-webui/modules/deepbooru_model.py
ADDED
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from modules import devices
|
6 |
+
|
7 |
+
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
8 |
+
|
9 |
+
|
10 |
+
class DeepDanbooruModel(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super(DeepDanbooruModel, self).__init__()
|
13 |
+
|
14 |
+
self.tags = []
|
15 |
+
|
16 |
+
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
17 |
+
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
18 |
+
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
19 |
+
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
20 |
+
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
21 |
+
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
22 |
+
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
23 |
+
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
24 |
+
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
25 |
+
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
26 |
+
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
27 |
+
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
28 |
+
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
29 |
+
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
30 |
+
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
31 |
+
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
32 |
+
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
33 |
+
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
34 |
+
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
35 |
+
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
36 |
+
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
37 |
+
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
38 |
+
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
39 |
+
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
40 |
+
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
41 |
+
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
42 |
+
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
43 |
+
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
44 |
+
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
45 |
+
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
46 |
+
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
47 |
+
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
48 |
+
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
49 |
+
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
50 |
+
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
51 |
+
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
52 |
+
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
53 |
+
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
54 |
+
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
55 |
+
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
56 |
+
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
57 |
+
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
58 |
+
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
59 |
+
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
60 |
+
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
61 |
+
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
62 |
+
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
63 |
+
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
64 |
+
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
65 |
+
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
66 |
+
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
67 |
+
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
68 |
+
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
69 |
+
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
70 |
+
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
71 |
+
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
72 |
+
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
73 |
+
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
74 |
+
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
75 |
+
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
76 |
+
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
77 |
+
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
78 |
+
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
79 |
+
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
80 |
+
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
81 |
+
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
82 |
+
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
83 |
+
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
84 |
+
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
85 |
+
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
86 |
+
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
87 |
+
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
88 |
+
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
89 |
+
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
90 |
+
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
91 |
+
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
92 |
+
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
93 |
+
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
94 |
+
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
95 |
+
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
96 |
+
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
97 |
+
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
98 |
+
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
99 |
+
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
100 |
+
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
101 |
+
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
102 |
+
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
103 |
+
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
104 |
+
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
105 |
+
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
106 |
+
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
107 |
+
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
108 |
+
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
109 |
+
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
110 |
+
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
111 |
+
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
112 |
+
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
113 |
+
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
114 |
+
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
115 |
+
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
116 |
+
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
117 |
+
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
118 |
+
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
119 |
+
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
120 |
+
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
121 |
+
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
122 |
+
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
123 |
+
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
124 |
+
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
125 |
+
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
126 |
+
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
127 |
+
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
128 |
+
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
129 |
+
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
130 |
+
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
131 |
+
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
132 |
+
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
133 |
+
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
134 |
+
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
135 |
+
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
136 |
+
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
137 |
+
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
138 |
+
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
139 |
+
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
140 |
+
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
141 |
+
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
142 |
+
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
143 |
+
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
144 |
+
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
145 |
+
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
146 |
+
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
147 |
+
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
148 |
+
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
149 |
+
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
150 |
+
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
151 |
+
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
152 |
+
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
153 |
+
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
154 |
+
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
155 |
+
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
156 |
+
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
157 |
+
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
158 |
+
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
159 |
+
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
160 |
+
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
161 |
+
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
162 |
+
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
163 |
+
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
164 |
+
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
165 |
+
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
166 |
+
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
167 |
+
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
168 |
+
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
169 |
+
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
170 |
+
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
171 |
+
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
172 |
+
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
173 |
+
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
174 |
+
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
175 |
+
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
176 |
+
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
177 |
+
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
178 |
+
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
179 |
+
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
180 |
+
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
181 |
+
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
182 |
+
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
183 |
+
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
184 |
+
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
185 |
+
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
186 |
+
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
187 |
+
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
188 |
+
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
189 |
+
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
190 |
+
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
191 |
+
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
192 |
+
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
193 |
+
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
194 |
+
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
195 |
+
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
196 |
+
|
197 |
+
def forward(self, *inputs):
|
198 |
+
t_358, = inputs
|
199 |
+
t_359 = t_358.permute(*[0, 3, 1, 2])
|
200 |
+
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
201 |
+
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
202 |
+
t_361 = F.relu(t_360)
|
203 |
+
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
204 |
+
t_362 = self.n_MaxPool_0(t_361)
|
205 |
+
t_363 = self.n_Conv_1(t_362)
|
206 |
+
t_364 = self.n_Conv_2(t_362)
|
207 |
+
t_365 = F.relu(t_364)
|
208 |
+
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
209 |
+
t_366 = self.n_Conv_3(t_365_padded)
|
210 |
+
t_367 = F.relu(t_366)
|
211 |
+
t_368 = self.n_Conv_4(t_367)
|
212 |
+
t_369 = torch.add(t_368, t_363)
|
213 |
+
t_370 = F.relu(t_369)
|
214 |
+
t_371 = self.n_Conv_5(t_370)
|
215 |
+
t_372 = F.relu(t_371)
|
216 |
+
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
217 |
+
t_373 = self.n_Conv_6(t_372_padded)
|
218 |
+
t_374 = F.relu(t_373)
|
219 |
+
t_375 = self.n_Conv_7(t_374)
|
220 |
+
t_376 = torch.add(t_375, t_370)
|
221 |
+
t_377 = F.relu(t_376)
|
222 |
+
t_378 = self.n_Conv_8(t_377)
|
223 |
+
t_379 = F.relu(t_378)
|
224 |
+
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
225 |
+
t_380 = self.n_Conv_9(t_379_padded)
|
226 |
+
t_381 = F.relu(t_380)
|
227 |
+
t_382 = self.n_Conv_10(t_381)
|
228 |
+
t_383 = torch.add(t_382, t_377)
|
229 |
+
t_384 = F.relu(t_383)
|
230 |
+
t_385 = self.n_Conv_11(t_384)
|
231 |
+
t_386 = self.n_Conv_12(t_384)
|
232 |
+
t_387 = F.relu(t_386)
|
233 |
+
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
234 |
+
t_388 = self.n_Conv_13(t_387_padded)
|
235 |
+
t_389 = F.relu(t_388)
|
236 |
+
t_390 = self.n_Conv_14(t_389)
|
237 |
+
t_391 = torch.add(t_390, t_385)
|
238 |
+
t_392 = F.relu(t_391)
|
239 |
+
t_393 = self.n_Conv_15(t_392)
|
240 |
+
t_394 = F.relu(t_393)
|
241 |
+
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
242 |
+
t_395 = self.n_Conv_16(t_394_padded)
|
243 |
+
t_396 = F.relu(t_395)
|
244 |
+
t_397 = self.n_Conv_17(t_396)
|
245 |
+
t_398 = torch.add(t_397, t_392)
|
246 |
+
t_399 = F.relu(t_398)
|
247 |
+
t_400 = self.n_Conv_18(t_399)
|
248 |
+
t_401 = F.relu(t_400)
|
249 |
+
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
250 |
+
t_402 = self.n_Conv_19(t_401_padded)
|
251 |
+
t_403 = F.relu(t_402)
|
252 |
+
t_404 = self.n_Conv_20(t_403)
|
253 |
+
t_405 = torch.add(t_404, t_399)
|
254 |
+
t_406 = F.relu(t_405)
|
255 |
+
t_407 = self.n_Conv_21(t_406)
|
256 |
+
t_408 = F.relu(t_407)
|
257 |
+
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
258 |
+
t_409 = self.n_Conv_22(t_408_padded)
|
259 |
+
t_410 = F.relu(t_409)
|
260 |
+
t_411 = self.n_Conv_23(t_410)
|
261 |
+
t_412 = torch.add(t_411, t_406)
|
262 |
+
t_413 = F.relu(t_412)
|
263 |
+
t_414 = self.n_Conv_24(t_413)
|
264 |
+
t_415 = F.relu(t_414)
|
265 |
+
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
266 |
+
t_416 = self.n_Conv_25(t_415_padded)
|
267 |
+
t_417 = F.relu(t_416)
|
268 |
+
t_418 = self.n_Conv_26(t_417)
|
269 |
+
t_419 = torch.add(t_418, t_413)
|
270 |
+
t_420 = F.relu(t_419)
|
271 |
+
t_421 = self.n_Conv_27(t_420)
|
272 |
+
t_422 = F.relu(t_421)
|
273 |
+
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
274 |
+
t_423 = self.n_Conv_28(t_422_padded)
|
275 |
+
t_424 = F.relu(t_423)
|
276 |
+
t_425 = self.n_Conv_29(t_424)
|
277 |
+
t_426 = torch.add(t_425, t_420)
|
278 |
+
t_427 = F.relu(t_426)
|
279 |
+
t_428 = self.n_Conv_30(t_427)
|
280 |
+
t_429 = F.relu(t_428)
|
281 |
+
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
282 |
+
t_430 = self.n_Conv_31(t_429_padded)
|
283 |
+
t_431 = F.relu(t_430)
|
284 |
+
t_432 = self.n_Conv_32(t_431)
|
285 |
+
t_433 = torch.add(t_432, t_427)
|
286 |
+
t_434 = F.relu(t_433)
|
287 |
+
t_435 = self.n_Conv_33(t_434)
|
288 |
+
t_436 = F.relu(t_435)
|
289 |
+
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
290 |
+
t_437 = self.n_Conv_34(t_436_padded)
|
291 |
+
t_438 = F.relu(t_437)
|
292 |
+
t_439 = self.n_Conv_35(t_438)
|
293 |
+
t_440 = torch.add(t_439, t_434)
|
294 |
+
t_441 = F.relu(t_440)
|
295 |
+
t_442 = self.n_Conv_36(t_441)
|
296 |
+
t_443 = self.n_Conv_37(t_441)
|
297 |
+
t_444 = F.relu(t_443)
|
298 |
+
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
299 |
+
t_445 = self.n_Conv_38(t_444_padded)
|
300 |
+
t_446 = F.relu(t_445)
|
301 |
+
t_447 = self.n_Conv_39(t_446)
|
302 |
+
t_448 = torch.add(t_447, t_442)
|
303 |
+
t_449 = F.relu(t_448)
|
304 |
+
t_450 = self.n_Conv_40(t_449)
|
305 |
+
t_451 = F.relu(t_450)
|
306 |
+
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
307 |
+
t_452 = self.n_Conv_41(t_451_padded)
|
308 |
+
t_453 = F.relu(t_452)
|
309 |
+
t_454 = self.n_Conv_42(t_453)
|
310 |
+
t_455 = torch.add(t_454, t_449)
|
311 |
+
t_456 = F.relu(t_455)
|
312 |
+
t_457 = self.n_Conv_43(t_456)
|
313 |
+
t_458 = F.relu(t_457)
|
314 |
+
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
315 |
+
t_459 = self.n_Conv_44(t_458_padded)
|
316 |
+
t_460 = F.relu(t_459)
|
317 |
+
t_461 = self.n_Conv_45(t_460)
|
318 |
+
t_462 = torch.add(t_461, t_456)
|
319 |
+
t_463 = F.relu(t_462)
|
320 |
+
t_464 = self.n_Conv_46(t_463)
|
321 |
+
t_465 = F.relu(t_464)
|
322 |
+
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
323 |
+
t_466 = self.n_Conv_47(t_465_padded)
|
324 |
+
t_467 = F.relu(t_466)
|
325 |
+
t_468 = self.n_Conv_48(t_467)
|
326 |
+
t_469 = torch.add(t_468, t_463)
|
327 |
+
t_470 = F.relu(t_469)
|
328 |
+
t_471 = self.n_Conv_49(t_470)
|
329 |
+
t_472 = F.relu(t_471)
|
330 |
+
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
331 |
+
t_473 = self.n_Conv_50(t_472_padded)
|
332 |
+
t_474 = F.relu(t_473)
|
333 |
+
t_475 = self.n_Conv_51(t_474)
|
334 |
+
t_476 = torch.add(t_475, t_470)
|
335 |
+
t_477 = F.relu(t_476)
|
336 |
+
t_478 = self.n_Conv_52(t_477)
|
337 |
+
t_479 = F.relu(t_478)
|
338 |
+
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
339 |
+
t_480 = self.n_Conv_53(t_479_padded)
|
340 |
+
t_481 = F.relu(t_480)
|
341 |
+
t_482 = self.n_Conv_54(t_481)
|
342 |
+
t_483 = torch.add(t_482, t_477)
|
343 |
+
t_484 = F.relu(t_483)
|
344 |
+
t_485 = self.n_Conv_55(t_484)
|
345 |
+
t_486 = F.relu(t_485)
|
346 |
+
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
347 |
+
t_487 = self.n_Conv_56(t_486_padded)
|
348 |
+
t_488 = F.relu(t_487)
|
349 |
+
t_489 = self.n_Conv_57(t_488)
|
350 |
+
t_490 = torch.add(t_489, t_484)
|
351 |
+
t_491 = F.relu(t_490)
|
352 |
+
t_492 = self.n_Conv_58(t_491)
|
353 |
+
t_493 = F.relu(t_492)
|
354 |
+
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
355 |
+
t_494 = self.n_Conv_59(t_493_padded)
|
356 |
+
t_495 = F.relu(t_494)
|
357 |
+
t_496 = self.n_Conv_60(t_495)
|
358 |
+
t_497 = torch.add(t_496, t_491)
|
359 |
+
t_498 = F.relu(t_497)
|
360 |
+
t_499 = self.n_Conv_61(t_498)
|
361 |
+
t_500 = F.relu(t_499)
|
362 |
+
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
363 |
+
t_501 = self.n_Conv_62(t_500_padded)
|
364 |
+
t_502 = F.relu(t_501)
|
365 |
+
t_503 = self.n_Conv_63(t_502)
|
366 |
+
t_504 = torch.add(t_503, t_498)
|
367 |
+
t_505 = F.relu(t_504)
|
368 |
+
t_506 = self.n_Conv_64(t_505)
|
369 |
+
t_507 = F.relu(t_506)
|
370 |
+
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
371 |
+
t_508 = self.n_Conv_65(t_507_padded)
|
372 |
+
t_509 = F.relu(t_508)
|
373 |
+
t_510 = self.n_Conv_66(t_509)
|
374 |
+
t_511 = torch.add(t_510, t_505)
|
375 |
+
t_512 = F.relu(t_511)
|
376 |
+
t_513 = self.n_Conv_67(t_512)
|
377 |
+
t_514 = F.relu(t_513)
|
378 |
+
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
379 |
+
t_515 = self.n_Conv_68(t_514_padded)
|
380 |
+
t_516 = F.relu(t_515)
|
381 |
+
t_517 = self.n_Conv_69(t_516)
|
382 |
+
t_518 = torch.add(t_517, t_512)
|
383 |
+
t_519 = F.relu(t_518)
|
384 |
+
t_520 = self.n_Conv_70(t_519)
|
385 |
+
t_521 = F.relu(t_520)
|
386 |
+
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
387 |
+
t_522 = self.n_Conv_71(t_521_padded)
|
388 |
+
t_523 = F.relu(t_522)
|
389 |
+
t_524 = self.n_Conv_72(t_523)
|
390 |
+
t_525 = torch.add(t_524, t_519)
|
391 |
+
t_526 = F.relu(t_525)
|
392 |
+
t_527 = self.n_Conv_73(t_526)
|
393 |
+
t_528 = F.relu(t_527)
|
394 |
+
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
395 |
+
t_529 = self.n_Conv_74(t_528_padded)
|
396 |
+
t_530 = F.relu(t_529)
|
397 |
+
t_531 = self.n_Conv_75(t_530)
|
398 |
+
t_532 = torch.add(t_531, t_526)
|
399 |
+
t_533 = F.relu(t_532)
|
400 |
+
t_534 = self.n_Conv_76(t_533)
|
401 |
+
t_535 = F.relu(t_534)
|
402 |
+
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
403 |
+
t_536 = self.n_Conv_77(t_535_padded)
|
404 |
+
t_537 = F.relu(t_536)
|
405 |
+
t_538 = self.n_Conv_78(t_537)
|
406 |
+
t_539 = torch.add(t_538, t_533)
|
407 |
+
t_540 = F.relu(t_539)
|
408 |
+
t_541 = self.n_Conv_79(t_540)
|
409 |
+
t_542 = F.relu(t_541)
|
410 |
+
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
411 |
+
t_543 = self.n_Conv_80(t_542_padded)
|
412 |
+
t_544 = F.relu(t_543)
|
413 |
+
t_545 = self.n_Conv_81(t_544)
|
414 |
+
t_546 = torch.add(t_545, t_540)
|
415 |
+
t_547 = F.relu(t_546)
|
416 |
+
t_548 = self.n_Conv_82(t_547)
|
417 |
+
t_549 = F.relu(t_548)
|
418 |
+
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
419 |
+
t_550 = self.n_Conv_83(t_549_padded)
|
420 |
+
t_551 = F.relu(t_550)
|
421 |
+
t_552 = self.n_Conv_84(t_551)
|
422 |
+
t_553 = torch.add(t_552, t_547)
|
423 |
+
t_554 = F.relu(t_553)
|
424 |
+
t_555 = self.n_Conv_85(t_554)
|
425 |
+
t_556 = F.relu(t_555)
|
426 |
+
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
427 |
+
t_557 = self.n_Conv_86(t_556_padded)
|
428 |
+
t_558 = F.relu(t_557)
|
429 |
+
t_559 = self.n_Conv_87(t_558)
|
430 |
+
t_560 = torch.add(t_559, t_554)
|
431 |
+
t_561 = F.relu(t_560)
|
432 |
+
t_562 = self.n_Conv_88(t_561)
|
433 |
+
t_563 = F.relu(t_562)
|
434 |
+
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
435 |
+
t_564 = self.n_Conv_89(t_563_padded)
|
436 |
+
t_565 = F.relu(t_564)
|
437 |
+
t_566 = self.n_Conv_90(t_565)
|
438 |
+
t_567 = torch.add(t_566, t_561)
|
439 |
+
t_568 = F.relu(t_567)
|
440 |
+
t_569 = self.n_Conv_91(t_568)
|
441 |
+
t_570 = F.relu(t_569)
|
442 |
+
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
443 |
+
t_571 = self.n_Conv_92(t_570_padded)
|
444 |
+
t_572 = F.relu(t_571)
|
445 |
+
t_573 = self.n_Conv_93(t_572)
|
446 |
+
t_574 = torch.add(t_573, t_568)
|
447 |
+
t_575 = F.relu(t_574)
|
448 |
+
t_576 = self.n_Conv_94(t_575)
|
449 |
+
t_577 = F.relu(t_576)
|
450 |
+
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
451 |
+
t_578 = self.n_Conv_95(t_577_padded)
|
452 |
+
t_579 = F.relu(t_578)
|
453 |
+
t_580 = self.n_Conv_96(t_579)
|
454 |
+
t_581 = torch.add(t_580, t_575)
|
455 |
+
t_582 = F.relu(t_581)
|
456 |
+
t_583 = self.n_Conv_97(t_582)
|
457 |
+
t_584 = F.relu(t_583)
|
458 |
+
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
459 |
+
t_585 = self.n_Conv_98(t_584_padded)
|
460 |
+
t_586 = F.relu(t_585)
|
461 |
+
t_587 = self.n_Conv_99(t_586)
|
462 |
+
t_588 = self.n_Conv_100(t_582)
|
463 |
+
t_589 = torch.add(t_587, t_588)
|
464 |
+
t_590 = F.relu(t_589)
|
465 |
+
t_591 = self.n_Conv_101(t_590)
|
466 |
+
t_592 = F.relu(t_591)
|
467 |
+
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
468 |
+
t_593 = self.n_Conv_102(t_592_padded)
|
469 |
+
t_594 = F.relu(t_593)
|
470 |
+
t_595 = self.n_Conv_103(t_594)
|
471 |
+
t_596 = torch.add(t_595, t_590)
|
472 |
+
t_597 = F.relu(t_596)
|
473 |
+
t_598 = self.n_Conv_104(t_597)
|
474 |
+
t_599 = F.relu(t_598)
|
475 |
+
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
476 |
+
t_600 = self.n_Conv_105(t_599_padded)
|
477 |
+
t_601 = F.relu(t_600)
|
478 |
+
t_602 = self.n_Conv_106(t_601)
|
479 |
+
t_603 = torch.add(t_602, t_597)
|
480 |
+
t_604 = F.relu(t_603)
|
481 |
+
t_605 = self.n_Conv_107(t_604)
|
482 |
+
t_606 = F.relu(t_605)
|
483 |
+
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
484 |
+
t_607 = self.n_Conv_108(t_606_padded)
|
485 |
+
t_608 = F.relu(t_607)
|
486 |
+
t_609 = self.n_Conv_109(t_608)
|
487 |
+
t_610 = torch.add(t_609, t_604)
|
488 |
+
t_611 = F.relu(t_610)
|
489 |
+
t_612 = self.n_Conv_110(t_611)
|
490 |
+
t_613 = F.relu(t_612)
|
491 |
+
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
492 |
+
t_614 = self.n_Conv_111(t_613_padded)
|
493 |
+
t_615 = F.relu(t_614)
|
494 |
+
t_616 = self.n_Conv_112(t_615)
|
495 |
+
t_617 = torch.add(t_616, t_611)
|
496 |
+
t_618 = F.relu(t_617)
|
497 |
+
t_619 = self.n_Conv_113(t_618)
|
498 |
+
t_620 = F.relu(t_619)
|
499 |
+
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
500 |
+
t_621 = self.n_Conv_114(t_620_padded)
|
501 |
+
t_622 = F.relu(t_621)
|
502 |
+
t_623 = self.n_Conv_115(t_622)
|
503 |
+
t_624 = torch.add(t_623, t_618)
|
504 |
+
t_625 = F.relu(t_624)
|
505 |
+
t_626 = self.n_Conv_116(t_625)
|
506 |
+
t_627 = F.relu(t_626)
|
507 |
+
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
508 |
+
t_628 = self.n_Conv_117(t_627_padded)
|
509 |
+
t_629 = F.relu(t_628)
|
510 |
+
t_630 = self.n_Conv_118(t_629)
|
511 |
+
t_631 = torch.add(t_630, t_625)
|
512 |
+
t_632 = F.relu(t_631)
|
513 |
+
t_633 = self.n_Conv_119(t_632)
|
514 |
+
t_634 = F.relu(t_633)
|
515 |
+
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
516 |
+
t_635 = self.n_Conv_120(t_634_padded)
|
517 |
+
t_636 = F.relu(t_635)
|
518 |
+
t_637 = self.n_Conv_121(t_636)
|
519 |
+
t_638 = torch.add(t_637, t_632)
|
520 |
+
t_639 = F.relu(t_638)
|
521 |
+
t_640 = self.n_Conv_122(t_639)
|
522 |
+
t_641 = F.relu(t_640)
|
523 |
+
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
524 |
+
t_642 = self.n_Conv_123(t_641_padded)
|
525 |
+
t_643 = F.relu(t_642)
|
526 |
+
t_644 = self.n_Conv_124(t_643)
|
527 |
+
t_645 = torch.add(t_644, t_639)
|
528 |
+
t_646 = F.relu(t_645)
|
529 |
+
t_647 = self.n_Conv_125(t_646)
|
530 |
+
t_648 = F.relu(t_647)
|
531 |
+
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
532 |
+
t_649 = self.n_Conv_126(t_648_padded)
|
533 |
+
t_650 = F.relu(t_649)
|
534 |
+
t_651 = self.n_Conv_127(t_650)
|
535 |
+
t_652 = torch.add(t_651, t_646)
|
536 |
+
t_653 = F.relu(t_652)
|
537 |
+
t_654 = self.n_Conv_128(t_653)
|
538 |
+
t_655 = F.relu(t_654)
|
539 |
+
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
540 |
+
t_656 = self.n_Conv_129(t_655_padded)
|
541 |
+
t_657 = F.relu(t_656)
|
542 |
+
t_658 = self.n_Conv_130(t_657)
|
543 |
+
t_659 = torch.add(t_658, t_653)
|
544 |
+
t_660 = F.relu(t_659)
|
545 |
+
t_661 = self.n_Conv_131(t_660)
|
546 |
+
t_662 = F.relu(t_661)
|
547 |
+
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
548 |
+
t_663 = self.n_Conv_132(t_662_padded)
|
549 |
+
t_664 = F.relu(t_663)
|
550 |
+
t_665 = self.n_Conv_133(t_664)
|
551 |
+
t_666 = torch.add(t_665, t_660)
|
552 |
+
t_667 = F.relu(t_666)
|
553 |
+
t_668 = self.n_Conv_134(t_667)
|
554 |
+
t_669 = F.relu(t_668)
|
555 |
+
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
556 |
+
t_670 = self.n_Conv_135(t_669_padded)
|
557 |
+
t_671 = F.relu(t_670)
|
558 |
+
t_672 = self.n_Conv_136(t_671)
|
559 |
+
t_673 = torch.add(t_672, t_667)
|
560 |
+
t_674 = F.relu(t_673)
|
561 |
+
t_675 = self.n_Conv_137(t_674)
|
562 |
+
t_676 = F.relu(t_675)
|
563 |
+
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
564 |
+
t_677 = self.n_Conv_138(t_676_padded)
|
565 |
+
t_678 = F.relu(t_677)
|
566 |
+
t_679 = self.n_Conv_139(t_678)
|
567 |
+
t_680 = torch.add(t_679, t_674)
|
568 |
+
t_681 = F.relu(t_680)
|
569 |
+
t_682 = self.n_Conv_140(t_681)
|
570 |
+
t_683 = F.relu(t_682)
|
571 |
+
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
572 |
+
t_684 = self.n_Conv_141(t_683_padded)
|
573 |
+
t_685 = F.relu(t_684)
|
        t_686 = self.n_Conv_142(t_685)
        t_687 = torch.add(t_686, t_681)
        t_688 = F.relu(t_687)
        t_689 = self.n_Conv_143(t_688)
        t_690 = F.relu(t_689)
        t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
        t_691 = self.n_Conv_144(t_690_padded)
        t_692 = F.relu(t_691)
        t_693 = self.n_Conv_145(t_692)
        t_694 = torch.add(t_693, t_688)
        t_695 = F.relu(t_694)
        t_696 = self.n_Conv_146(t_695)
        t_697 = F.relu(t_696)
        t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
        t_698 = self.n_Conv_147(t_697_padded)
        t_699 = F.relu(t_698)
        t_700 = self.n_Conv_148(t_699)
        t_701 = torch.add(t_700, t_695)
        t_702 = F.relu(t_701)
        t_703 = self.n_Conv_149(t_702)
        t_704 = F.relu(t_703)
        t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
        t_705 = self.n_Conv_150(t_704_padded)
        t_706 = F.relu(t_705)
        t_707 = self.n_Conv_151(t_706)
        t_708 = torch.add(t_707, t_702)
        t_709 = F.relu(t_708)
        t_710 = self.n_Conv_152(t_709)
        t_711 = F.relu(t_710)
        t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
        t_712 = self.n_Conv_153(t_711_padded)
        t_713 = F.relu(t_712)
        t_714 = self.n_Conv_154(t_713)
        t_715 = torch.add(t_714, t_709)
        t_716 = F.relu(t_715)
        t_717 = self.n_Conv_155(t_716)
        t_718 = F.relu(t_717)
        t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
        t_719 = self.n_Conv_156(t_718_padded)
        t_720 = F.relu(t_719)
        t_721 = self.n_Conv_157(t_720)
        t_722 = torch.add(t_721, t_716)
        t_723 = F.relu(t_722)
        t_724 = self.n_Conv_158(t_723)
        t_725 = self.n_Conv_159(t_723)
        t_726 = F.relu(t_725)
        t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
        t_727 = self.n_Conv_160(t_726_padded)
        t_728 = F.relu(t_727)
        t_729 = self.n_Conv_161(t_728)
        t_730 = torch.add(t_729, t_724)
        t_731 = F.relu(t_730)
        t_732 = self.n_Conv_162(t_731)
        t_733 = F.relu(t_732)
        t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
        t_734 = self.n_Conv_163(t_733_padded)
        t_735 = F.relu(t_734)
        t_736 = self.n_Conv_164(t_735)
        t_737 = torch.add(t_736, t_731)
        t_738 = F.relu(t_737)
        t_739 = self.n_Conv_165(t_738)
        t_740 = F.relu(t_739)
        t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
        t_741 = self.n_Conv_166(t_740_padded)
        t_742 = F.relu(t_741)
        t_743 = self.n_Conv_167(t_742)
        t_744 = torch.add(t_743, t_738)
        t_745 = F.relu(t_744)
        t_746 = self.n_Conv_168(t_745)
        t_747 = self.n_Conv_169(t_745)
        t_748 = F.relu(t_747)
        t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
        t_749 = self.n_Conv_170(t_748_padded)
        t_750 = F.relu(t_749)
        t_751 = self.n_Conv_171(t_750)
        t_752 = torch.add(t_751, t_746)
        t_753 = F.relu(t_752)
        t_754 = self.n_Conv_172(t_753)
        t_755 = F.relu(t_754)
        t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
        t_756 = self.n_Conv_173(t_755_padded)
        t_757 = F.relu(t_756)
        t_758 = self.n_Conv_174(t_757)
        t_759 = torch.add(t_758, t_753)
        t_760 = F.relu(t_759)
        t_761 = self.n_Conv_175(t_760)
        t_762 = F.relu(t_761)
        t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
        t_763 = self.n_Conv_176(t_762_padded)
        t_764 = F.relu(t_763)
        t_765 = self.n_Conv_177(t_764)
        t_766 = torch.add(t_765, t_760)
        t_767 = F.relu(t_766)
        t_768 = self.n_Conv_178(t_767)
        t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
        t_770 = torch.squeeze(t_769, 3)
        t_770 = torch.squeeze(t_770, 2)
        t_771 = torch.sigmoid(t_770)
        return t_771

    def load_state_dict(self, state_dict, **kwargs):
        self.tags = state_dict.get('tags', [])

        super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
sd/stable-diffusion-webui/modules/devices.py
ADDED
@@ -0,0 +1,152 @@
import sys
import contextlib
import torch
from modules import errors

if sys.platform == "darwin":
    from modules import mac_specific


def has_mps() -> bool:
    if sys.platform != "darwin":
        return False
    else:
        return mac_specific.has_mps


def extract_device_id(args, name):
    for x in range(len(args)):
        if name in args[x]:
            return args[x + 1]

    return None


def get_cuda_device_string():
    from modules import shared

    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"

    return "cuda"


def get_optimal_device_name():
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return "mps"

    return "cpu"


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    from modules import shared

    if task in shared.cmd_opts.use_cpu:
        return cpu

    return get_optimal_device()


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


def enable_tf32():
    if torch.cuda.is_available():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


errors.run(enable_tf32, "Enabling TF32")

cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
    return input.float() if unet_needs_upcast else input


def randn(seed, shape):
    torch.manual_seed(seed)
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


def autocast(disable=False):
    from modules import shared

    if disable:
        return contextlib.nullcontext()

    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")


def without_autocast(disable=False):
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


class NansException(Exception):
    pass


def test_for_nans(x, where):
    from modules import shared

    if shared.cmd_opts.disable_nan_check:
        return

    if not torch.all(torch.isnan(x)).item():
        return

    if where == "unet":
        message = "A tensor with all NaNs was produced in Unet."

        if not shared.cmd_opts.no_half:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

    elif where == "vae":
        message = "A tensor with all NaNs was produced in VAE."

        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with all NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)
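A small sketch of the contract the randn helper above provides: seeding goes through torch.manual_seed, so the same seed yields the same noise tensor. This assumes devices.device has been initialized, which the webui startup code normally does; here it is set by hand:

import torch
from modules import devices

devices.device = torch.device("cpu")  # normally assigned during webui startup
a = devices.randn(42, (1, 4, 64, 64))
b = devices.randn(42, (1, 4, 64, 64))
print(torch.equal(a, b))  # True: same seed, same noise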
sd/stable-diffusion-webui/modules/errors.py
ADDED
@@ -0,0 +1,43 @@
import sys
import traceback


def print_error_explanation(message):
    lines = message.strip().split("\n")
    max_len = max([len(x) for x in lines])

    print('=' * max_len, file=sys.stderr)
    for line in lines:
        print(line, file=sys.stderr)
    print('=' * max_len, file=sys.stderr)


def display(e: Exception, task):
    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    message = str(e)
    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
        print_error_explanation("""
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
        """)


already_displayed = {}


def display_once(e: Exception, task):
    if task in already_displayed:
        return

    display(e, task)

    already_displayed[task] = 1


def run(code, task):
    try:
        code()
    except Exception as e:
        display(e, task)  # fixed: the arguments were swapped; display expects the exception first
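errors.run is the pattern used elsewhere in this upload to keep optional setup from crashing startup (devices.py wraps enable_tf32 in it). A minimal sketch:

from modules import errors

def flaky_init():
    raise RuntimeError("boom")

errors.run(flaky_init, "initializing optional feature")  # prints the traceback to stderr and continues
print("still running")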
sd/stable-diffusion-webui/modules/esrgan_model.py
ADDED
@@ -0,0 +1,233 @@
import os

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts


def mod2normal(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    if 'conv_first.weight' in state_dict:
        crt_net = {}
        items = []
        for k, v in state_dict.items():
            items.append(k)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']

        for k in items.copy():
            if 'RDB' in k:
                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
                if '.weight' in k:
                    ori_k = ori_k.replace('.weight', '.0.weight')
                elif '.bias' in k:
                    ori_k = ori_k.replace('.bias', '.0.bias')
                crt_net[ori_k] = state_dict[k]
                items.remove(k)

        crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
        crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
        crt_net['model.3.weight'] = state_dict['upconv1.weight']
        crt_net['model.3.bias'] = state_dict['upconv1.bias']
        crt_net['model.6.weight'] = state_dict['upconv2.weight']
        crt_net['model.6.bias'] = state_dict['upconv2.bias']
        crt_net['model.8.weight'] = state_dict['HRconv.weight']
        crt_net['model.8.bias'] = state_dict['HRconv.bias']
        crt_net['model.10.weight'] = state_dict['conv_last.weight']
        crt_net['model.10.bias'] = state_dict['conv_last.bias']
        state_dict = crt_net
    return state_dict


def resrgan2normal(state_dict, nb=23):
    # this code is copied from https://github.com/victorca25/iNNfer
    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
        re8x = 0
        crt_net = {}
        items = []
        for k, v in state_dict.items():
            items.append(k)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']

        for k in items.copy():
            if "rdb" in k:
                ori_k = k.replace('body.', 'model.1.sub.')
                ori_k = ori_k.replace('.rdb', '.RDB')
                if '.weight' in k:
                    ori_k = ori_k.replace('.weight', '.0.weight')
                elif '.bias' in k:
                    ori_k = ori_k.replace('.bias', '.0.bias')
                crt_net[ori_k] = state_dict[k]
                items.remove(k)

        crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
        crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
        crt_net['model.3.weight'] = state_dict['conv_up1.weight']
        crt_net['model.3.bias'] = state_dict['conv_up1.bias']
        crt_net['model.6.weight'] = state_dict['conv_up2.weight']
        crt_net['model.6.bias'] = state_dict['conv_up2.bias']

        if 'conv_up3.weight' in state_dict:
            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
            re8x = 3
            crt_net['model.9.weight'] = state_dict['conv_up3.weight']
            crt_net['model.9.bias'] = state_dict['conv_up3.bias']

        crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
        crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
        crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
        crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']

        state_dict = crt_net
    return state_dict


def infer_params(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    plus = False

    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if (part_num > scalemin
                    and parts[0] == "model"
                    and parts[2] == "weight"):
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2 ** scale2x

    return in_nc, out_nc, nf, nb, plus, scale


class UpscalerESRGAN(Upscaler):
    def __init__(self, dirname):
        self.name = "ESRGAN"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
        self.model_name = "ESRGAN_4x"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        if len(model_paths) == 0:
            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
            self.scalers.append(scaler_data)  # fixed: was appended to an unused local list, so the default scaler was lost
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)

            scaler_data = UpscalerData(name, file, self, 4)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        model = self.load_model(selected_model)
        if model is None:
            return img
        model.to(devices.device_esrgan)
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                          file_name="%s.pth" % self.model_name,
                                          progress=True)
        else:
            filename = path
        if filename is None or not os.path.exists(filename):  # fixed: check for None before os.path.exists to avoid a TypeError
            print("Unable to load %s from %s" % (self.model_path, filename))
            return None

        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

        if "params_ema" in state_dict:
            state_dict = state_dict["params_ema"]
        elif "params" in state_dict:
            state_dict = state_dict["params"]
            num_conv = 16 if "realesr-animevideov3" in filename else 32
            model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
            model.load_state_dict(state_dict)
            model.eval()
            return model

        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
            state_dict = resrgan2normal(state_dict, nb)
        elif "conv_first.weight" in state_dict:
            state_dict = mod2normal(state_dict)
        elif "model.0.weight" not in state_dict:
            raise Exception("The file is not a recognized ESRGAN model.")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)

        model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
        model.load_state_dict(state_dict)
        model.eval()

        return model


def upscale_without_tiling(model, img):
    img = np.array(img)
    img = img[:, :, ::-1]
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_esrgan)
    with torch.no_grad():
        output = model(img)
    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = 255. * np.moveaxis(output, 0, 2)
    output = output.astype(np.uint8)
    output = output[:, :, ::-1]
    return Image.fromarray(output, 'RGB')


def esrgan_upscale(model, img):
    if opts.ESRGAN_tile == 0:
        return upscale_without_tiling(model, img)

    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
    newtiles = []
    scale_factor = 1

    for y, h, row in grid.tiles:
        newrow = []
        for tiledata in row:
            x, w, tile = tiledata

            output = upscale_without_tiling(model, tile)
            scale_factor = output.width // tile.width

            newrow.append([x * scale_factor, w * scale_factor, output])
        newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = images.combine_grid(newgrid)
    return output
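The BGR/CHW juggling in upscale_without_tiling is easy to get wrong, so here is a stand-alone sketch of just that pre/post-processing, with an identity module standing in for the ESRGAN network (nothing below touches devices or shared):

import numpy as np
import torch
from PIL import Image

class IdentityModel(torch.nn.Module):  # stand-in for the real upscaler
    def forward(self, x):
        return x

img = Image.new('RGB', (8, 8), color=(255, 0, 0))
arr = np.array(img)[:, :, ::-1]                                           # RGB -> BGR
ten = torch.from_numpy(np.ascontiguousarray(np.transpose(arr, (2, 0, 1)))).float() / 255
with torch.no_grad():
    out = IdentityModel()(ten.unsqueeze(0))
out = out.squeeze().clamp_(0, 1).numpy()
out = (255. * np.moveaxis(out, 0, 2)).astype(np.uint8)[:, :, ::-1]        # BGR -> RGB
print(Image.fromarray(np.ascontiguousarray(out), 'RGB').getpixel((0, 0)))  # (255, 0, 0)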
sd/stable-diffusion-webui/modules/esrgan_model_arch.py
ADDED
@@ -0,0 +1,464 @@
# this file is adapted from https://github.com/victorca25/iNNfer

from collections import OrderedDict
import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F


####################
# RRDBNet Generator
####################

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
                 finalact=None, gaussian_noise=False, plus=False):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.resrgan_scale = 0
        if in_nc % 16 == 0:
            self.resrgan_scale = 1
        elif in_nc != 4 and in_nc % 4 == 0:
            self.resrgan_scale = 2

        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                          norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
                          gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            upsample_block = upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        outact = act(finalact) if finalact else None

        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
                                *upsampler, HR_conv0, HR_conv1, outact)

    def forward(self, x, outm=None):
        if self.resrgan_scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        elif self.resrgan_scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        else:
            feat = x

        return self.model(feat)


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
                 spectral_norm=False, gaussian_noise=False, plus=False):
        super(RRDB, self).__init__()
        # This is for backwards compatibility with existing models
        if nr == 3:
            self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
            self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
            self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
        else:
            RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
            self.RDBs = nn.Sequential(*RDB_list)

    def forward(self, x):
        if hasattr(self, 'RDB1'):
            out = self.RDB1(x)
            out = self.RDB2(out)
            out = self.RDB3(out)
        else:
            out = self.RDBs(x)
        return out * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}
    """

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
                 spectral_norm=False, gaussian_noise=False, plus=False):
        super(ResidualDenseBlock_5C, self).__init__()

        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None

        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1:
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1:
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        if self.noise:
            return self.noise(x5.mul(0.2) + x)
        else:
            return x5 * 0.2 + x


####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if self.training and self.sigma != 0:
            self.noise = self.noise.to(x.device)
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


####################
# SRVGGNetCompact
####################

class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.
    This class is copied from https://github.com/xinntao/Real-ESRGAN
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out


####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.
    """

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super(Upsample, self).__init__()
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.size = size
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)

    def extra_repr(self):
        if self.scale_factor is not None:
            info = 'scale_factor=' + str(self.scale_factor)
        else:
            info = 'size=' + str(self.size)
        info += ', mode=' + self.mode
        return info


def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.
    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.
    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                       pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)


def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """ Upconv layer """
    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
    return sequential(upsample, conv)


####################
# Basic blocks
####################


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
    Args:
        basic_block (nn.module): nn.module class for basic block. (block)
        num_basic_block (int): number of blocks. (n_layers)
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    """ activation helper """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type in ('leakyrelu', 'lrelu'):
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act_type == 'tanh':  # [-1, 1] range output
        layer = nn.Tanh()
    elif act_type == 'sigmoid':  # [0, 1] range output
        layer = nn.Sigmoid()
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


class Identity(nn.Module):
    def __init__(self, *kwargs):
        super(Identity, self).__init__()

    def forward(self, x, *kwargs):
        return x


def norm(norm_type, nc):
    """ Return a normalization layer """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    elif norm_type == 'none':
        layer = Identity()  # fixed: this branch previously defined an unused inner function and fell through to an undefined name
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def pad(pad_type, padding):
    """ padding layer helper """
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    elif pad_type == 'zero':
        layer = nn.ZeroPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ShortcutBlock(nn.Module):
    """ Elementwise sum the output of a submodule to its input """
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')


def sequential(*args):
    """ Flatten Sequential. It unwraps nn.Sequential. """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
               spectral_norm=False):
    """ Conv layer with padding, normalization, activation """
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    # note: PartialConv2d and DeformConv2d come from the original iNNfer code base and are not
    # defined in this file; only the 'Conv2D' (and 'Conv3D') paths are exercised here
    if convtype == 'PartialConv2D':
        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                          dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'DeformConv2D':
        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                         dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'Conv3D':
        c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)
    else:
        c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)

    if spectral_norm:
        c = nn.utils.spectral_norm(c)

    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)
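pixel_unshuffle is how RRDBNet.forward feeds 1x/2x Real-ESRGAN checkpoints through a backbone built for 4x: spatial resolution is folded into channels. A quick check of the shape contract:

import torch
from modules.esrgan_model_arch import pixel_unshuffle

x = torch.arange(16.0).view(1, 1, 4, 4)
y = pixel_unshuffle(x, scale=2)
print(y.shape)  # torch.Size([1, 4, 2, 2]): (b, c * scale**2, h / scale, w / scale)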
sd/stable-diffusion-webui/modules/extensions.py
ADDED
@@ -0,0 +1,107 @@
import os
import sys
import traceback

import time
import git

from modules import paths, shared

extensions = []
extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

if not os.path.exists(extensions_dir):
    os.makedirs(extensions_dir)


def active():
    return [x for x in extensions if x.enabled]


class Extension:
    def __init__(self, name, path, enabled=True, is_builtin=False):
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin
        self.version = ''

        repo = None
        try:
            if os.path.exists(os.path.join(path, ".git")):
                repo = git.Repo(path)
        except Exception:
            print(f"Error reading github repository info from {path}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        if repo is None or repo.bare:
            self.remote = None
        else:
            try:
                self.remote = next(repo.remote().urls, None)
                self.status = 'unknown'
                head = repo.head.commit
                ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
                self.version = f'{head.hexsha[:8]} ({ts})'

            except Exception:
                self.remote = None

    def list_files(self, subdir, extension):
        from modules import scripts

        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = []
        for filename in sorted(os.listdir(dirpath)):
            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))

        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

        return res

    def check_updates(self):
        repo = git.Repo(self.path)
        for fetch in repo.remote().fetch("--dry-run"):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "behind"
                return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self):
        repo = git.Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker sets 755 file permissions instead of 644, which results in the error.
        repo.git.fetch('--all')
        repo.git.reset('--hard', 'origin')


def list_extensions():
    extensions.clear()

    if not os.path.isdir(extensions_dir):
        return

    paths = []
    for dirname in [extensions_dir, extensions_builtin_dir]:
        if not os.path.isdir(dirname):
            continue  # fixed: was `return`, which silently discarded everything whenever one directory was missing

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            paths.append((extension_dirname, path, dirname == extensions_builtin_dir))

    for dirname, path, is_builtin in paths:
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
        extensions.append(extension)
ADDED
@@ -0,0 +1,147 @@
|
import re
from collections import defaultdict

from modules import errors

extra_network_registry = {}


def initialize():
    extra_network_registry.clear()


def register_extra_network(extra_network):
    extra_network_registry[extra_network.name] = extra_network


class ExtraNetworkParams:
    def __init__(self, items=None):
        self.items = items or []


class ExtraNetwork:
    def __init__(self, name):
        self.name = name

    def activate(self, p, params_list):
        """
        Called by processing on every run. Whatever the extra network is meant to do should be activated here.
        Passes arguments related to this extra network in params_list.
        The user passes arguments by specifying this in the prompt:

        <name:arg1:arg2:arg3>

        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
        separated by colons.

        Even if the user does not mention this ExtraNetwork in the prompt, the call will still be made, with an empty params_list -
        in this case, all effects of this extra network should be disabled.

        Can be called multiple times before deactivate() - each new call should override the previous call completely.

        For example, if this ExtraNetwork's name is 'hypernet' and the user's prompt is:

        > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"

        params_list will be:

        [
            ExtraNetworkParams(items=["agm", "1.1"]),
            ExtraNetworkParams(items=["ray"])
        ]

        """
        raise NotImplementedError

    def deactivate(self, p):
        """
        Called at the end of processing for housekeeping. No need to do anything here.
        """

        raise NotImplementedError


def activate(p, extra_network_data):
    """call activate for extra networks in extra_network_data in specified order, then call
    activate for all remaining registered networks with an empty argument list"""

    for extra_network_name, extra_network_args in extra_network_data.items():
        extra_network = extra_network_registry.get(extra_network_name, None)
        if extra_network is None:
            print(f"Skipping unknown extra network: {extra_network_name}")
            continue

        try:
            extra_network.activate(p, extra_network_args)
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")

    for extra_network_name, extra_network in extra_network_registry.items():
        args = extra_network_data.get(extra_network_name, None)
        if args is not None:
            continue

        try:
            extra_network.activate(p, [])
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name}")


def deactivate(p, extra_network_data):
    """call deactivate for extra networks in extra_network_data in specified order, then call
    deactivate for all remaining registered networks"""

    for extra_network_name, extra_network_args in extra_network_data.items():
        extra_network = extra_network_registry.get(extra_network_name, None)
        if extra_network is None:
            continue

        try:
            extra_network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating extra network {extra_network_name}")

    for extra_network_name, extra_network in extra_network_registry.items():
        args = extra_network_data.get(extra_network_name, None)
        if args is not None:
            continue

        try:
            extra_network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")


re_extra_net = re.compile(r"<(\w+):([^>]+)>")


def parse_prompt(prompt):
    res = defaultdict(list)

    def found(m):
        name = m.group(1)
        args = m.group(2)

        res[name].append(ExtraNetworkParams(items=args.split(":")))

        return ""

    prompt = re.sub(re_extra_net, found, prompt)

    return prompt, res


def parse_prompts(prompts):
    res = []
    extra_data = None

    for prompt in prompts:
        updated_prompt, parsed_extra_data = parse_prompt(prompt)

        if extra_data is None:
            extra_data = parsed_extra_data

        res.append(updated_prompt)

    return res, extra_data
sd/stable-diffusion-webui/modules/extra_networks_hypernet.py
ADDED
@@ -0,0 +1,27 @@
from modules import extra_networks, shared  # fixed: extra_networks was imported twice
from modules.hypernetworks import hypernetwork


class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('hypernet')

    def activate(self, p, params_list):
        additional = shared.opts.sd_hypernetwork

        if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
            p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
        multipliers = []
        for params in params_list:
            assert len(params.items) > 0

            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

        hypernetwork.load_hypernetworks(names, multipliers)

    def deactivate(self, p):
        pass
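The name/multiplier unpacking in activate follows the generic <name:arg1:arg2> convention; a toy sketch of just that parsing step, without loading any hypernetwork:

from modules.extra_networks import ExtraNetworkParams

params_list = [ExtraNetworkParams(items=["agm", "1.1"]), ExtraNetworkParams(items=["ray"])]
names = [p.items[0] for p in params_list]
multipliers = [float(p.items[1]) if len(p.items) > 1 else 1.0 for p in params_list]
print(names, multipliers)  # ['agm', 'ray'] [1.1, 1.0]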
sd/stable-diffusion-webui/modules/extras.py
ADDED
@@ -0,0 +1,258 @@
import os
import re
import shutil


import torch
import tqdm

from modules import shared, images, sd_models, sd_vae, sd_models_config
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch


def run_pnginfo(image):
    if image is None:
        return '', '', ''

    geninfo, items = images.read_info_from_image(image)
    items = {**{'parameters': geninfo}, **items}

    info = ''
    for key, text in items.items():
        info += f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()+"\n"

    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}<p></div>"

    return '', geninfo, info


def create_config(ckpt_result, config_source, a, b, c):
    def config(x):
        res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
        return res if res != shared.sd_default_config else None

    if config_source == 0:
        cfg = config(a) or config(b) or config(c)
    elif config_source == 1:
        cfg = config(b)
    elif config_source == 2:
        cfg = config(c)
    else:
        cfg = None

    if cfg is None:
        return

    filename, _ = os.path.splitext(ckpt_result)
    checkpoint_filename = filename + ".yaml"

    print("Copying config:")
    print(" from:", cfg)
    print(" to:", checkpoint_filename)
    shutil.copyfile(cfg, checkpoint_filename)


checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]


def to_half(tensor, enable):
    if enable and tensor.dtype == torch.float:
        return tensor.half()

    return tensor


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
    shared.state.begin()
    shared.state.job = 'model-merge'

    def fail(message):
        shared.state.textinfo = message
        shared.state.end()
        return [*[gr.update() for _ in range(4)], message]

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def get_difference(theta1, theta2):
        return theta1 - theta2

    def add_difference(theta0, theta1_2_diff, alpha):
        return theta0 + (alpha * theta1_2_diff)

    def filename_weighted_sum():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        Ma = round(1 - multiplier, 2)
        Mb = round(multiplier, 2)

        return f"{Ma}({a}) + {Mb}({b})"

    def filename_add_difference():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        c = tertiary_model_info.model_name
        M = round(multiplier, 2)

        return f"{a} + {M}({b} - {c})"

    def filename_nothing():
        return primary_model_info.model_name

    theta_funcs = {
        "Weighted sum": (filename_weighted_sum, None, weighted_sum),
        "Add difference": (filename_add_difference, get_difference, add_difference),
        "No interpolation": (filename_nothing, None, None),
    }
    filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
    shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)

    if not primary_model_name:
        return fail("Failed: Merging requires a primary model.")

    primary_model_info = sd_models.checkpoints_list[primary_model_name]

    if theta_func2 and not secondary_model_name:
        return fail("Failed: Merging requires a secondary model.")

    secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None

    if theta_func1 and not tertiary_model_name:
        return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")

    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None

    result_is_inpainting_model = False
    result_is_instruct_pix2pix_model = False

    if theta_func2:
        shared.state.textinfo = "Loading B"
        print(f"Loading {secondary_model_info.filename}...")
        theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
    else:
        theta_1 = None

    if theta_func1:
        shared.state.textinfo = "Loading C"
        print(f"Loading {tertiary_model_info.filename}...")
        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

        shared.state.textinfo = 'Merging B and C'
        shared.state.sampling_steps = len(theta_1.keys())
        for key in tqdm.tqdm(theta_1.keys()):
            if key in checkpoint_dict_skip_on_merge:
                continue

            if 'model' in key:
                if key in theta_2:
                    t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
                    theta_1[key] = theta_func1(theta_1[key], t2)
                else:
                    theta_1[key] = torch.zeros_like(theta_1[key])

            shared.state.sampling_step += 1
        del theta_2

        shared.state.nextjob()

    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')

    print("Merging...")
    shared.state.textinfo = 'Merging A and B'
    shared.state.sampling_steps = len(theta_0.keys())
    for key in tqdm.tqdm(theta_0.keys()):
        if theta_1 and 'model' in key and key in theta_1:

            if key in checkpoint_dict_skip_on_merge:
                continue

            a = theta_0[key]
            b = theta_1[key]

            # this enables merging an inpainting model (A) with another one (B);
            # where a normal model would have 4 channels for the latent space, an inpainting model
            # has another 4 channels for the unmasked picture's latent space, plus one channel for the mask, for a total of 9
            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
                if a.shape[1] == 4 and b.shape[1] == 8:
                    raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")

                if a.shape[1] == 8 and b.shape[1] == 4:  # if we have an Instruct-Pix2Pix model...
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)  # merge only the channels the models have in common, otherwise we get a dimension mismatch error
                    result_is_instruct_pix2pix_model = True
                else:
                    assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)
|
201 |
+
|
202 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
203 |
+
|
204 |
+
shared.state.sampling_step += 1
|
205 |
+
|
206 |
+
del theta_1
|
207 |
+
|
208 |
+
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
209 |
+
if bake_in_vae_filename is not None:
|
210 |
+
print(f"Baking in VAE from {bake_in_vae_filename}")
|
211 |
+
shared.state.textinfo = 'Baking in VAE'
|
212 |
+
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
213 |
+
|
214 |
+
for key in vae_dict.keys():
|
215 |
+
theta_0_key = 'first_stage_model.' + key
|
216 |
+
if theta_0_key in theta_0:
|
217 |
+
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
218 |
+
|
219 |
+
del vae_dict
|
220 |
+
|
221 |
+
if save_as_half and not theta_func2:
|
222 |
+
for key in theta_0.keys():
|
223 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
224 |
+
|
225 |
+
if discard_weights:
|
226 |
+
regex = re.compile(discard_weights)
|
227 |
+
for key in list(theta_0):
|
228 |
+
if re.search(regex, key):
|
229 |
+
theta_0.pop(key, None)
|
230 |
+
|
231 |
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
232 |
+
|
233 |
+
filename = filename_generator() if custom_name == '' else custom_name
|
234 |
+
filename += ".inpainting" if result_is_inpainting_model else ""
|
235 |
+
filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
|
236 |
+
filename += "." + checkpoint_format
|
237 |
+
|
238 |
+
output_modelname = os.path.join(ckpt_dir, filename)
|
239 |
+
|
240 |
+
shared.state.nextjob()
|
241 |
+
shared.state.textinfo = "Saving"
|
242 |
+
print(f"Saving to {output_modelname}...")
|
243 |
+
|
244 |
+
_, extension = os.path.splitext(output_modelname)
|
245 |
+
if extension.lower() == ".safetensors":
|
246 |
+
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
|
247 |
+
else:
|
248 |
+
torch.save(theta_0, output_modelname)
|
249 |
+
|
250 |
+
sd_models.list_models()
|
251 |
+
|
252 |
+
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
253 |
+
|
254 |
+
print(f"Checkpoint saved to {output_modelname}.")
|
255 |
+
shared.state.textinfo = "Checkpoint saved"
|
256 |
+
shared.state.end()
|
257 |
+
|
258 |
+
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
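
The core of the merge above is a per-tensor interpolation over two state dicts. A minimal self-contained sketch of the "Weighted sum" case, with small illustrative tensors standing in for real checkpoints loaded via sd_models.read_state_dict:

import torch

# Illustrative stand-ins for two checkpoint state dicts; real ones hold
# thousands of tensors keyed by layer name.
theta_a = {"model.layer.weight": torch.ones(4, 4)}
theta_b = {"model.layer.weight": torch.zeros(4, 4)}

multiplier = 0.3  # same meaning as the UI slider: the weight given to model B

# "Weighted sum": theta = (1 - alpha) * A + alpha * B, applied key by key.
merged = {}
for key, a in theta_a.items():
    b = theta_b.get(key)
    if b is not None and a.shape == b.shape:
        merged[key] = (1 - multiplier) * a + multiplier * b
    else:
        merged[key] = a  # keys missing from B are kept from A unchanged

print(merged["model.layer.weight"][0, 0])  # tensor(0.7000)
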
sd/stable-diffusion-webui/modules/face_restoration.py
ADDED
@@ -0,0 +1,19 @@
from modules import shared


class FaceRestoration:
    def name(self):
        return "None"

    def restore(self, np_image):
        return np_image


def restore_faces(np_image):
    face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
    if len(face_restorers) == 0:
        return np_image

    face_restorer = face_restorers[0]

    return face_restorer.restore(np_image)
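
This class is the extension point for face restorers: restore_faces picks whichever registered restorer matches the face_restoration_model option. A minimal sketch of plugging in a custom restorer (the IdentityRestorer name here is made up for illustration):

from modules import shared, face_restoration

class IdentityRestorer(face_restoration.FaceRestoration):
    """Hypothetical restorer that returns the image unchanged."""
    def name(self):
        return "Identity"

    def restore(self, np_image):
        # A real implementation would detect faces and enhance them here.
        return np_image

# Appending to shared.face_restorers makes it selectable in the UI via
# shared.opts.face_restoration_model.
shared.face_restorers.append(IdentityRestorer())
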
sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py
ADDED
@@ -0,0 +1,402 @@
import base64
import html
import io
import math
import os
import re
from pathlib import Path

import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image

re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())

paste_fields = {}
registered_param_bindings = []


class ParamBinding:
    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
        self.paste_button = paste_button
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component


def reset():
    paste_fields.clear()


def quote(text):
    if ',' not in str(text):
        return text

    text = str(text)
    text = text.replace('\\', '\\\\')
    text = text.replace('"', '\\"')
    return f'"{text}"'


def image_from_url_text(filedata):
    if filedata is None:
        return None

    if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if type(filedata) == dict and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        assert is_in_right_dir, 'trying to open image file outside of allowed directories'

        return Image.open(filename)

    if type(filedata) == list:
        if len(filedata) == 0:
            return None

        filedata = filedata[0]

    if filedata.startswith("data:image/png;base64,"):
        filedata = filedata[len("data:image/png;base64,"):]

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    return image


def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}

    # backwards compatibility for existing extensions
    import modules.ui
    if tabname == 'txt2img':
        modules.ui.txt2img_paste_fields = fields
    elif tabname == 'img2img':
        modules.ui.img2img_paste_fields = fields


def create_buttons(tabs_list):
    buttons = {}
    for tab in tabs_list:
        buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
    return buttons


def bind_buttons(buttons, send_image, send_generate_info):
    """old function for backwards compatibility; do not use this, use register_paste_params_button"""
    for tabname, button in buttons.items():
        source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
        source_tabname = send_generate_info if isinstance(send_generate_info, str) else None

        register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))


def register_paste_params_button(binding: ParamBinding):
    registered_param_bindings.append(binding)


def connect_paste_params_buttons():
    binding: ParamBinding
    for binding in registered_param_bindings:
        destination_image_component = paste_fields[binding.tabname]["init_img"]
        fields = paste_fields[binding.tabname]["fields"]
        override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]

        destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
        destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)

        if binding.source_image_component and destination_image_component:
            if isinstance(binding.source_image_component, gr.Gallery):
                func = send_image_and_dimensions if destination_width_component else image_from_url_text
                jsfunc = "extract_image_from_gallery"
            else:
                func = send_image_and_dimensions if destination_width_component else lambda x: x
                jsfunc = None

            binding.paste_button.click(
                fn=func,
                _js=jsfunc,
                inputs=[binding.source_image_component],
                outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
            )

        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
            binding.paste_button.click(
                fn=lambda *x: x,
                inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
                outputs=[field for field, name in fields if name in paste_field_names],
            )

        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=None,
            outputs=None,
        )


def send_image_and_dimensions(x):
    if isinstance(x, Image.Image):
        img = x
    else:
        img = image_from_url_text(x)

    if shared.opts.send_size and isinstance(img, Image.Image):
        w = img.width
        h = img.height
    else:
        w = gr.update()
        h = gr.update()

    return img, w, h


def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.

    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.

    If the infotext has no hash, then a hypernet with the same name will be selected instead.
    """
    hypernet_name = hypernet_name.lower()
    if hypernet_hash is not None:
        # Try to match the hash in the name
        for hypernet_key in shared.hypernetworks.keys():
            result = re_hypernet_hash.search(hypernet_key)
            if result is not None and result[1] == hypernet_hash:
                return hypernet_key
    else:
        # Fall back to a hypernet with the same name
        for hypernet_key in shared.hypernetworks.keys():
            if hypernet_key.lower().startswith(hypernet_name):
                return hypernet_key

    return None


def restore_old_hires_fix_params(res):
    """for infotexts that specify old First pass size parameter, convert it into
    width, height, and hr scale"""

    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)

    if shared.opts.use_old_hires_fix_width_height:
        hires_width = int(res.get("Hires resize-1", 0))
        hires_height = int(res.get("Hires resize-2", 0))

        if hires_width and hires_height:
            res['Size-1'] = hires_width
            res['Size-2'] = hires_height
            return

    if firstpass_width is None or firstpass_height is None:
        return

    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))

    if firstpass_width == 0 or firstpass_height == 0:
        from modules import processing
        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires resize-1'] = width
    res['Hires resize-2'] = height


def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
```

    returns a dict with field values
    """

    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    if len(re_param.findall(lastline)) < 3:
        lines.append(lastline)
        lastline = ''

    for i, line in enumerate(lines):
        line = line.strip()
        if line.startswith("Negative prompt:"):
            done_with_prompt = True
            line = line[16:].strip()

        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
        m = re_imagesize.match(v)
        if m is not None:
            res[k+"-1"] = m.group(1)
            res[k+"-2"] = m.group(2)
        else:
            res[k] = v

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
        res["Clip skip"] = "1"

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    restore_old_hires_fix_params(res)

    return res


settings_map = {}

infotext_to_setting_name_mapping = [
    ('Clip skip', 'CLIP_stop_at_last_layers', ),
    ('Conditional mask weight', 'inpainting_mask_weight'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
    ('Noise multiplier', 'initial_noise_multiplier'),
    ('Eta', 'eta_ancestral'),
    ('Eta DDIM', 'eta_ddim'),
    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
]


def create_override_settings_dict(text_pairs):
    """creates processing's override_settings parameters from gradio's multiselect

    Example input:
        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']

    Example output:
        {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
    """

    res = {}

    params = {}
    for pair in text_pairs:
        k, v = pair.split(":", maxsplit=1)

        params[k] = v.strip()

    for param_name, setting_name in infotext_to_setting_name_mapping:
        value = params.get(param_name, None)

        if value is None:
            continue

        res[setting_name] = shared.opts.cast_value(setting_name, value)

    return res


def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
    def paste_func(prompt):
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            filename = os.path.join(data_path, "params.txt")
            if os.path.exists(filename):
                with open(filename, "r", encoding="utf8") as file:
                    prompt = file.read()

        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)
        res = []

        for output, key in paste_fields:
            if callable(key):
                v = key(params)
            else:
                v = params.get(key, None)

            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
            else:
                try:
                    valtype = type(output.value)

                    if valtype == bool and v == "False":
                        val = False
                    else:
                        val = valtype(v)

                    res.append(gr.update(value=val))
                except Exception:
                    res.append(gr.update())

        return res

    if override_settings_component is not None:
        def paste_settings(params):
            vals = {}

            for param_name, setting_name in infotext_to_setting_name_mapping:
                v = params.get(param_name, None)
                if v is None:
                    continue

                if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
                    continue

                v = shared.opts.cast_value(setting_name, v)
                current_value = getattr(shared.opts, setting_name, None)

                if v == current_value:
                    continue

                vals[param_name] = v

            vals_pairs = [f"{k}: {v}" for k, v in vals.items()]

            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)

        paste_fields = paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[input_comp],
        outputs=[x[0] for x in paste_fields],
    )
sd/stable-diffusion-webui/modules/gfpgan_model.py
ADDED
@@ -0,0 +1,116 @@
import os
import sys
import traceback

import facexlib
import gfpgan

import modules.face_restoration
from modules import paths, shared, devices, modelloader

model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(paths.models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None


def gfpgann():
    global loaded_gfpgan_model
    global model_path
    if loaded_gfpgan_model is not None:
        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
        return loaded_gfpgan_model

    if gfpgan_constructor is None:
        return None

    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
    if len(models) == 1 and "http" in models[0]:
        model_file = models[0]
    elif len(models) != 0:
        latest_file = max(models, key=os.path.getctime)
        model_file = latest_file
    else:
        print("Unable to load gfpgan model!")
        return None
    if hasattr(facexlib.detection.retinaface, 'device'):
        facexlib.detection.retinaface.device = devices.device_gfpgan
    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
    loaded_gfpgan_model = model

    return model


def send_model_to(model, device):
    model.gfpgan.to(device)
    model.face_helper.face_det.to(device)
    model.face_helper.face_parse.to(device)


def gfpgan_fix_faces(np_image):
    model = gfpgann()
    if model is None:
        return np_image

    send_model_to(model, devices.device_gfpgan)

    np_image_bgr = np_image[:, :, ::-1]
    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    np_image = gfpgan_output_bgr[:, :, ::-1]

    model.face_helper.clean_all()

    if shared.opts.face_restoration_unload:
        send_model_to(model, devices.cpu)

    return np_image


gfpgan_constructor = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    try:
        from gfpgan import GFPGANer
        from facexlib import detection, parsing
        global user_path
        global have_gfpgan
        global gfpgan_constructor

        load_file_from_url_orig = gfpgan.utils.load_file_from_url
        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url

        def my_load_file_from_url(**kwargs):
            return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))

        def facex_load_file_from_url(**kwargs):
            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))

        def facex_load_file_from_url2(**kwargs):
            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))

        gfpgan.utils.load_file_from_url = my_load_file_from_url
        facexlib.detection.load_file_from_url = facex_load_file_from_url
        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
        user_path = dirname
        have_gfpgan = True
        gfpgan_constructor = GFPGANer

        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
            def name(self):
                return "GFPGAN"

            def restore(self, np_image):
                return gfpgan_fix_faces(np_image)

        shared.face_restorers.append(FaceRestorerGFPGAN())
    except Exception:
        print("Error setting up GFPGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
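
In use, the module is driven through setup_model (which monkey-patches the gfpgan/facexlib download helpers so weights land under the webui's models directory) and then gfpgan_fix_faces on an RGB numpy array. A rough sketch of that flow, assuming it runs inside the webui and the GFPGAN weights can be found or downloaded (file paths are illustrative):

import numpy as np
from PIL import Image
from modules import gfpgan_model

gfpgan_model.setup_model("models/GFPGAN")  # hypothetical user-supplied dir

img = np.array(Image.open("face.png").convert("RGB"))  # H x W x 3, RGB
restored = gfpgan_model.gfpgan_fix_faces(img)          # converts to BGR internally
Image.fromarray(restored).save("face_restored.png")
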
sd/stable-diffusion-webui/modules/hashes.py
ADDED
@@ -0,0 +1,91 @@
import hashlib
import json
import os.path

import filelock

from modules import shared
from modules.paths import data_path


cache_filename = os.path.join(data_path, "cache.json")
cache_data = None


def dump_cache():
    with filelock.FileLock(cache_filename+".lock"):
        with open(cache_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)


def cache(subsection):
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename+".lock"):
            if not os.path.isfile(cache_filename):
                cache_data = {}
            else:
                with open(cache_filename, "r", encoding="utf8") as file:
                    cache_data = json.load(file)

    s = cache_data.get(subsection, {})
    cache_data[subsection] = s

    return s


def calculate_sha256(filename):
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            hash_sha256.update(chunk)

    return hash_sha256.hexdigest()


def sha256_from_cache(filename, title):
    hashes = cache("hashes")
    ondisk_mtime = os.path.getmtime(filename)

    if title not in hashes:
        return None

    cached_sha256 = hashes[title].get("sha256", None)
    cached_mtime = hashes[title].get("mtime", 0)

    if ondisk_mtime > cached_mtime or cached_sha256 is None:
        return None

    return cached_sha256


def sha256(filename, title):
    hashes = cache("hashes")

    sha256_value = sha256_from_cache(filename, title)
    if sha256_value is not None:
        return sha256_value

    if shared.cmd_opts.no_hashing:
        return None

    print(f"Calculating sha256 for {filename}: ", end='')
    sha256_value = calculate_sha256(filename)
    print(f"{sha256_value}")

    hashes[title] = {
        "mtime": os.path.getmtime(filename),
        "sha256": sha256_value,
    }

    dump_cache()

    return sha256_value
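
The mtime check in sha256_from_cache means a digest is recomputed only when the file changes on disk; until then, repeat calls are cache hits read from cache.json. A sketch of how this is typically invoked from within the webui (the file path and title are illustrative):

from modules import hashes

# First call streams the file in 1 MiB chunks and stores {"mtime", "sha256"}
# under the title; later calls return the cached digest until the file's
# mtime moves forward.
digest = hashes.sha256("models/Stable-diffusion/example.safetensors", "checkpoint/example")
short_hash = digest[0:10] if digest else None  # short form shown in the UI
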
sd/stable-diffusion-webui/modules/images.py
ADDED
@@ -0,0 +1,669 @@
import datetime
import sys
import traceback

import pytz
import io
import math
import os
from collections import namedtuple
import re

import numpy as np
import piexif
import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
import json
import hashlib

from modules import sd_samplers, shared, script_callbacks, errors
from modules.shared import opts, cmd_opts

LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)


def image_grid(imgs, batch_size=1, rows=None):
    if rows is None:
        if opts.n_rows > 0:
            rows = opts.n_rows
        elif opts.n_rows == 0:
            rows = batch_size
        elif opts.grid_prevent_empty_spots:
            rows = math.floor(math.sqrt(len(imgs)))
            while len(imgs) % rows != 0:
                rows -= 1
        else:
            rows = math.sqrt(len(imgs))
            rows = round(rows)
    if rows > len(imgs):
        rows = len(imgs)

    cols = math.ceil(len(imgs) / rows)

    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
    script_callbacks.image_grid_callback(params)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')

    for i, img in enumerate(params.imgs):
        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))

    return grid


Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    cols = math.ceil((w - overlap) / non_overlap_width)
    rows = math.ceil((h - overlap) / non_overlap_height)

    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = int(row * dy)

        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = int(col * dx)

            if x + tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid


def combine_grid(grid):
    def make_mask_image(r):
        r = r * 255 / grid.overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))

    return combined_image


class GridAnnotation:
    def __init__(self, text='', is_active=True):
        self.text = text
        self.is_active = is_active
        self.size = None


def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
    def wrap(drawing, text, font, line_length):
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def get_font(fontsize):
        try:
            return ImageFont.truetype(opts.font or Roboto, fontsize)
        except Exception:
            return ImageFont.truetype(Roboto, fontsize)

    def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        for i, line in enumerate(lines):
            fnt = initial_fnt
            fontsize = initial_fontsize
            while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
                fontsize -= 1
                fnt = get_font(fontsize)
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            if not line.is_active:
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2

    fnt = get_font(fontsize)

    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    calc_img = Image.new("RGB", (1, 1), "white")
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = [] + texts
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
            line.allowed_width = allowed_width

    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]

    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")

    for row in range(rows):
        for col in range(cols):
            cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
            result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + (width + margin) * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2

        draw_texts(d, x, y, hor_texts[col], fnt, fontsize)

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2

        draw_texts(d, x, y, ver_texts[row], fnt, fontsize)

    return result


def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
    prompts = all_prompts[1:]
    boundary = math.ceil(len(prompts) / 2)

    prompts_horiz = prompts[:boundary]
    prompts_vert = prompts[boundary:]

    hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
    ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]

    return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)


def resize_image(resize_mode, im, width, height, upscaler_name=None):
    """
    Resizes an image with the specified resize_mode, width, and height.

    Args:
        resize_mode: The mode to use when resizing the image.
            0: Resize the image to the specified width and height.
            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
        im: The image to resize.
        width: The width to resize the image to.
        height: The height to resize the image to.
        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
    """

    upscaler_name = upscaler_name or opts.upscaler_for_img2img

    def resize(im, w, h):
        if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
            return im.resize((w, h), resample=LANCZOS)

        scale = max(w / im.width, h / im.height)

        if scale > 1.0:
            upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
            assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"

            upscaler = upscalers[0]
            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)

        if im.width != w or im.height != h:
            im = im.resize((w, h), resample=LANCZOS)

        return im

    if resize_mode == 0:
        res = resize(im, width, height)

    elif resize_mode == 1:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    else:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            fill_height = height // 2 - src_h // 2
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            fill_width = width // 2 - src_w // 2
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res


invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128


def sanitize_filename_part(text, replace_spaces=True):
    if text is None:
        return None

    if replace_spaces:
        text = text.replace(' ', '_')

    text = text.translate({ord(x): '_' for x in invalid_filename_chars})
    text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
    text = text.rstrip(invalid_filename_postfix)
    return text


class FilenameGenerator:
    replacements = {
        'seed': lambda self: self.seed if self.seed is not None else '',
        'steps': lambda self: self.p and self.p.steps,
        'cfg': lambda self: self.p and self.p.cfg_scale,
        'width': lambda self: self.image.width,
        'height': lambda self: self.image.height,
        'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
        'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
        'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
        'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
        'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
        'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
        'prompt': lambda self: sanitize_filename_part(self.prompt),
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
        'prompt_words': lambda self: self.prompt_words(),
    }
    default_time_format = '%Y%m%d%H%M%S'

    def __init__(self, p, seed, prompt, image):
        self.p = p
        self.seed = seed
        self.prompt = prompt
        self.image = image

    def prompt_no_style(self):
        if self.p is None or self.prompt is None:
            return None

        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
            if len(style) > 0:
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')

                prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()

        return sanitize_filename_part(prompt_no_style, replace_spaces=False)

    def prompt_words(self):
        words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
        if len(words) == 0:
            words = ["empty"]
        return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)

    def datetime(self, *args):
        time_datetime = datetime.datetime.now()

        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError as _:
            time_zone = None

        time_zone_time = time_datetime.astimezone(time_zone)
        try:
            formatted_time = time_zone_time.strftime(time_format)
        except (ValueError, TypeError) as _:
            formatted_time = time_zone_time.strftime(self.default_time_format)

        return sanitize_filename_part(formatted_time, replace_spaces=False)

    def apply(self, x):
        res = ''

        for m in re_pattern.finditer(x):
            text, pattern = m.groups()
            res += text

            if pattern is None:
                continue

            pattern_args = []
            while True:
                m = re_pattern_arg.match(pattern)
                if m is None:
                    break

                pattern, arg = m.groups()
                pattern_args.insert(0, arg)

            fun = self.replacements.get(pattern.lower())
            if fun is not None:
                try:
                    replacement = fun(self, *pattern_args)
                except Exception:
                    replacement = None
                    print(f"Error adding [{pattern}] to filename", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                if replacement is not None:
                    res += str(replacement)
                    continue

            res += f'[{pattern}]'

        return res


def get_next_sequence_number(path, basename):
    """
    Determines and returns the next sequence number to use when saving an image in the specified directory.

    The sequence starts at 0.
    """
    result = -1
    if basename != '':
        basename = basename + "-"

    prefix_length = len(basename)
    for p in os.listdir(path):
        if p.startswith(basename):
            l = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
            try:
                result = max(int(l[0]), result)
            except ValueError:
                pass

    return result + 1


def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
    """Save an image.

    Args:
        image (`PIL.Image`):
            The image to be saved.
        path (`str`):
            The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
        basename (`str`):
            The base filename which will be applied to `filename pattern`.
        seed, prompt, short_filename,
        extension (`str`):
            Image file extension, default is `png`.
        pnginfo_section_name (`str`):
            Specify the name of the section in which `info` will be saved.
        info (`str` or `PngImagePlugin.iTXt`):
            PNG info chunks.
        existing_info (`dict`):
            Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
        no_prompt:
            TODO I don't know its meaning.
        p (`StableDiffusionProcessing`)
        forced_filename (`str`):
            If specified, `basename` and filename pattern will be ignored.
        save_to_dirs (bool):
            If true, the image will be saved into a subdirectory of `path`.

    Returns: (fullfn, txt_fullfn)
        fullfn (`str`):
            The full path of the saved image.
        txt_fullfn (`str` or None):
            If a text file is saved for this image, this will be its full path. Otherwise None.
    """
    namegen = FilenameGenerator(p, seed, prompt, image)

    if save_to_dirs is None:
        save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

    if save_to_dirs:
        dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    if forced_filename is None:
        if short_filename or seed is None:
            file_decoration = ""
        elif opts.save_to_dirs:
            file_decoration = opts.samples_filename_pattern or "[seed]"
        else:
            file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

        add_number = opts.save_images_add_number or file_decoration == ''

        if file_decoration != "" and add_number:
            file_decoration = "-" + file_decoration

        file_decoration = namegen.apply(file_decoration) + suffix

        if add_number:
            basecount = get_next_sequence_number(path, basename)
            fullfn = None
            for i in range(500):
                fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
                fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
                if not os.path.exists(fullfn):
                    break
            else:
                fullfn = os.path.join(path, f"{file_decoration}.{extension}")
    else:
        fullfn = os.path.join(path, f"{forced_filename}.{extension}")

    pnginfo = existing_info or {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info

    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
    script_callbacks.before_image_saved_callback(params)

    image = params.image
    fullfn = params.filename
    info = params.pnginfo.get(pnginfo_section_name, None)

    def _atomically_save_image(image_to_save, filename_without_extension, extension):
        # save image with .tmp extension to avoid race condition when another process detects new image in the directory
        temp_file_path = filename_without_extension + ".tmp"
        image_format = Image.registered_extensions()[extension]

        if extension.lower() == '.png':
            pnginfo_data = PngImagePlugin.PngInfo()
            if opts.enable_pnginfo:
                for k, v in params.pnginfo.items():
                    pnginfo_data.add_text(k, str(v))

            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

        elif extension.lower() in (".jpg", ".jpeg", ".webp"):
            if image_to_save.mode == 'RGBA':
                image_to_save = image_to_save.convert("RGB")
            elif image_to_save.mode == 'I;16':
                image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")

            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

            if opts.enable_pnginfo and info is not None:
                exif_bytes = piexif.dump({
                    "Exif": {
                        piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
                    },
                })

                piexif.insert(exif_bytes, temp_file_path)
        else:
            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

        # atomically rename the file with correct extension
        os.replace(temp_file_path, filename_without_extension + extension)

    fullfn_without_extension, extension = os.path.splitext(params.filename)
    _atomically_save_image(image, fullfn_without_extension, extension)

    image.already_saved_as = fullfn

    oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
        elif oversize:
            image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)

        try:
            _atomically_save_image(image, fullfn_without_extension, ".jpg")
        except Exception as e:
            errors.display(e, "saving image as downscaled JPG")

    if opts.save_txt and info is not None:
        txt_fullfn = f"{fullfn_without_extension}.txt"
        with open(txt_fullfn, "w", encoding="utf8") as file:
            file.write(info + "\n")
    else:
        txt_fullfn = None

    script_callbacks.image_saved_callback(params)

    return fullfn, txt_fullfn


def read_info_from_image(image):
    items = image.info or {}

    geninfo = items.pop('parameters', None)

    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
614 |
+
try:
|
615 |
+
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
616 |
+
except ValueError:
|
617 |
+
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
618 |
+
|
619 |
+
if exif_comment:
|
620 |
+
items['exif comment'] = exif_comment
|
621 |
+
geninfo = exif_comment
|
622 |
+
|
623 |
+
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
624 |
+
'loop', 'background', 'timestamp', 'duration']:
|
625 |
+
items.pop(field, None)
|
626 |
+
|
627 |
+
if items.get("Software", None) == "NovelAI":
|
628 |
+
try:
|
629 |
+
json_info = json.loads(items["Comment"])
|
630 |
+
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
631 |
+
|
632 |
+
geninfo = f"""{items["Description"]}
|
633 |
+
Negative prompt: {json_info["uc"]}
|
634 |
+
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
635 |
+
except Exception:
|
636 |
+
print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
|
637 |
+
print(traceback.format_exc(), file=sys.stderr)
|
638 |
+
|
639 |
+
return geninfo, items
|
640 |
+
|
641 |
+
|
642 |
+
def image_data(data):
|
643 |
+
try:
|
644 |
+
image = Image.open(io.BytesIO(data))
|
645 |
+
textinfo, _ = read_info_from_image(image)
|
646 |
+
return textinfo, None
|
647 |
+
except Exception:
|
648 |
+
pass
|
649 |
+
|
650 |
+
try:
|
651 |
+
text = data.decode('utf8')
|
652 |
+
assert len(text) < 10000
|
653 |
+
return text, None
|
654 |
+
|
655 |
+
except Exception:
|
656 |
+
pass
|
657 |
+
|
658 |
+
return '', None
|
659 |
+
|
660 |
+
|
661 |
+
def flatten(img, bgcolor):
|
662 |
+
"""replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
|
663 |
+
|
664 |
+
if img.mode == "RGBA":
|
665 |
+
background = Image.new('RGBA', img.size, bgcolor)
|
666 |
+
background.paste(img, mask=img)
|
667 |
+
img = background
|
668 |
+
|
669 |
+
return img.convert('RGB')
|
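`_atomically_save_image` above writes to a `.tmp` path and then renames it into place, so a process watching the output directory never sees a half-written file. A minimal standalone sketch of the same write-then-rename pattern (the filename is illustrative):

```python
import os
from PIL import Image

def atomic_save(image: Image.Image, target_path: str) -> None:
    # Write under a temporary name first; os.replace is atomic on both POSIX
    # and Windows, so readers see either the old file or the complete new one.
    tmp_path = target_path + ".tmp"
    image.save(tmp_path, format="PNG")
    os.replace(tmp_path, target_path)

atomic_save(Image.new("RGB", (64, 64), "#336699"), "example.png")
```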
sd/stable-diffusion-webui/modules/img2img.py
ADDED
@@ -0,0 +1,184 @@
import math
import os
import sys
import traceback

import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops

from modules import devices, sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts


def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
    processing.fix_seed(p)

    images = shared.listfiles(input_dir)

    is_inpaint_batch = False
    if inpaint_mask_dir:
        inpaint_masks = shared.listfiles(inpaint_mask_dir)
        is_inpaint_batch = len(inpaint_masks) > 0
    if is_inpaint_batch:
        print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

    print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

    save_normally = output_dir == ''

    p.do_not_save_grid = True
    p.do_not_save_samples = not save_normally

    state.job_count = len(images) * p.n_iter

    for i, image in enumerate(images):
        state.job = f"{i+1} out of {len(images)}"
        if state.skipped:
            state.skipped = False

        if state.interrupted:
            break

        img = Image.open(image)
        # Use the EXIF orientation of photos taken by smartphones.
        img = ImageOps.exif_transpose(img)
        p.init_images = [img] * p.batch_size

        if is_inpaint_batch:
            # try to find corresponding mask for an image using simple filename matching
            mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
            # if not found use first one ("same mask for all images" use-case)
            if mask_image_path not in inpaint_masks:
                mask_image_path = inpaint_masks[0]
            mask_image = Image.open(mask_image_path)
            p.image_mask = mask_image

        proc = modules.scripts.scripts_img2img.run(p, *args)
        if proc is None:
            proc = process_images(p)

        for n, processed_image in enumerate(proc.images):
            filename = os.path.basename(image)

            if n > 0:
                left, right = os.path.splitext(filename)
                filename = f"{left}-{n}{right}"

            if not save_normally:
                os.makedirs(output_dir, exist_ok=True)
                if processed_image.mode == 'RGBA':
                    processed_image = processed_image.convert("RGB")
                processed_image.save(os.path.join(output_dir, filename))


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
    override_settings = create_override_settings_dict(override_settings_texts)

    is_batch = mode == 5

    if mode == 0:  # img2img
        image = init_img.convert("RGB")
        mask = None
    elif mode == 1:  # img2img sketch
        image = sketch.convert("RGB")
        mask = None
    elif mode == 2:  # inpaint
        image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
        image = image.convert("RGB")
    elif mode == 3:  # inpaint sketch
        image = inpaint_color_sketch
        orig = inpaint_color_sketch_orig or inpaint_color_sketch
        pred = np.any(np.array(image) != np.array(orig), axis=-1)
        mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
        mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
        blur = ImageFilter.GaussianBlur(mask_blur)
        image = Image.composite(image.filter(blur), orig, mask.filter(blur))
        image = image.convert("RGB")
    elif mode == 4:  # inpaint upload mask
        image = init_img_inpaint
        mask = init_mask_inpaint
    else:
        image = None
        mask = None

    # Use the EXIF orientation of photos taken by smartphones.
    if image is not None:
        image = ImageOps.exif_transpose(image)

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=prompt,
        negative_prompt=negative_prompt,
        styles=prompt_styles,
        seed=seed,
        subseed=subseed,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
        seed_enable_extras=seed_enable_extras,
        sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        restore_faces=restore_faces,
        tiling=tiling,
        init_images=[image],
        mask=mask,
        mask_blur=mask_blur,
        inpainting_fill=inpainting_fill,
        resize_mode=resize_mode,
        denoising_strength=denoising_strength,
        image_cfg_scale=image_cfg_scale,
        inpaint_full_res=inpaint_full_res,
        inpaint_full_res_padding=inpaint_full_res_padding,
        inpainting_mask_invert=inpainting_mask_invert,
        override_settings=override_settings,
    )

    # this is the img2img pipeline, so attach the img2img script runner
    p.scripts = modules.scripts.scripts_img2img
    p.script_args = args

    if shared.cmd_opts.enable_console_prompts:
        print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

    p.extra_generation_params["Mask blur"] = mask_blur

    if is_batch:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)

        processed = Processed(p, [], p.seed, "")
    else:
        processed = modules.scripts.scripts_img2img.run(p, *args)
        if processed is None:
            processed = process_images(p)

    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    if opts.samples_log_stdout:
        print(generation_info_js)

    if opts.do_not_show_images:
        processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
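Mode 3 above derives the inpaint mask from wherever the user's color sketch differs from the original image (`np.any(... != ..., axis=-1)`). A standalone sketch of that diff-to-mask step using synthetic images:

```python
import numpy as np
from PIL import Image, ImageDraw

orig = Image.new("RGB", (128, 128), "white")
sketch = orig.copy()
ImageDraw.Draw(sketch).rectangle([32, 32, 96, 96], fill="red")  # the user's edit

# Any channel differing marks a painted pixel; scale the boolean map to 0/255.
changed = np.any(np.array(sketch) != np.array(orig), axis=-1)
mask = Image.fromarray(changed.astype(np.uint8) * 255, "L")
print(mask.getextrema())  # (0, 255): both untouched and painted areas exist
```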
sd/stable-diffusion-webui/modules/import_hook.py
ADDED
@@ -0,0 +1,5 @@
import sys

# this will break any attempt to import xformers, which will prevent the stable diffusion repo from trying to use it
if "--xformers" not in "".join(sys.argv):
    sys.modules["xformers"] = None
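Setting `sys.modules[name] = None` makes any later `import name` raise `ImportError`, which is how this hook disables xformers before anything can load it. A self-contained demonstration with a made-up module name:

```python
import sys

sys.modules["definitely_not_installed"] = None  # poison the import cache

try:
    import definitely_not_installed  # noqa: F401
except ImportError as e:
    print("import blocked:", e)  # "import of definitely_not_installed halted; None in sys.modules"
```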
sd/stable-diffusion-webui/modules/interrogate.py
ADDED
@@ -0,0 +1,227 @@
import os
import sys
import shutil
import traceback
from collections import namedtuple
from pathlib import Path
import re

import torch
import torch.hub

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from modules import devices, paths, shared, lowvram, modelloader, errors

blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")

def category_types():
    return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]


def download_default_clip_interrogate_categories(content_dir):
    print("Downloading CLIP categories...")

    tmpdir = content_dir + "_tmp"
    category_types = ["artists", "flavors", "mediums", "movements"]

    try:
        os.makedirs(tmpdir)
        for category_type in category_types:
            torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
        os.rename(tmpdir, content_dir)

    except Exception as e:
        errors.display(e, "downloading default CLIP interrogate categories")
    finally:
        if os.path.exists(tmpdir):
            # remove the temporary directory along with any partially downloaded files
            shutil.rmtree(tmpdir)


class InterrogateModels:
    blip_model = None
    clip_model = None
    clip_preprocess = None
    dtype = None
    running_on_cpu = None

    def __init__(self, content_dir):
        self.loaded_categories = None
        self.skip_categories = []
        self.content_dir = content_dir
        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")

    def categories(self):
        if not os.path.exists(self.content_dir):
            download_default_clip_interrogate_categories(self.content_dir)

        if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
            return self.loaded_categories

        self.loaded_categories = []

        if os.path.exists(self.content_dir):
            self.skip_categories = shared.opts.interrogate_clip_skip_categories
            category_types = []
            for filename in Path(self.content_dir).glob('*.txt'):
                category_types.append(filename.stem)
                if filename.stem in self.skip_categories:
                    continue
                m = re_topn.search(filename.stem)
                topn = 1 if m is None else int(m.group(1))
                with open(filename, "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))

        return self.loaded_categories

    def create_fake_fairscale(self):
        class FakeFairscale:
            def checkpoint_wrapper(self):
                pass

        sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale

    def load_blip_model(self):
        self.create_fake_fairscale()
        import models.blip

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "BLIP"),
            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
            ext_filter=[".pth"],
            download_name='model_base_caption_capfilt_large.pth',
        )

        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        import clip

        if self.running_on_cpu:
            model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
        else:
            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)

        model.eval()
        model = model.to(devices.device_interrogate)

        return model, preprocess

    def load(self):
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(devices.device_interrogate)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(devices.device_interrogate)

        self.dtype = next(self.clip_model.parameters()).dtype

    def send_clip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()

    def rank(self, image_features, text_array, top_count=1):
        import clip

        devices.torch_gc()

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]

    def generate_caption(self, pil_image):
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]

    def interrogate(self, pil_image):
        res = ""
        shared.state.begin()
        shared.state.job = 'interrogate'
        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            caption = self.generate_caption(pil_image)
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

            with torch.no_grad(), devices.autocast():
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)

                for name, topn, items in self.categories():
                    matches = self.rank(image_features, items, top_count=topn)
                    for match, score in matches:
                        if shared.opts.interrogate_return_ranks:
                            res += f", ({match}:{score/100:.3f})"
                        else:
                            res += ", " + match

        except Exception:
            print("Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            res += "<error>"

        self.unload()
        shared.state.end()

        return res
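`rank()` normalizes both image and text features, takes a scaled dot product (cosine similarity), and softmaxes it into per-label probabilities. A minimal torch-only sketch of the same scoring, with random features standing in for CLIP outputs (the 100.0 scale and the shapes mirror the method above):

```python
import torch

labels = ["oil painting", "photograph", "sketch"]
image_features = torch.randn(1, 512)
text_features = torch.randn(len(labels), 512)

# Normalize, score by cosine similarity, convert to probabilities.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

top_probs, top_idx = similarity.topk(2, dim=-1)
for p, i in zip(top_probs[0], top_idx[0]):
    print(f"{labels[i]}: {p.item() * 100:.1f}%")
```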
sd/stable-diffusion-webui/modules/localization.py
ADDED
@@ -0,0 +1,37 @@
import json
import os
import sys
import traceback


localizations = {}


def list_localizations(dirname):
    localizations.clear()

    for file in os.listdir(dirname):
        fn, ext = os.path.splitext(file)
        if ext.lower() != ".json":
            continue

        localizations[fn] = os.path.join(dirname, file)

    from modules import scripts
    for file in scripts.list_scripts("localizations", ".json"):
        fn, ext = os.path.splitext(file.filename)
        localizations[fn] = file.path


def localization_js(current_localization_name):
    fn = localizations.get(current_localization_name, None)
    data = {}
    if fn is not None:
        try:
            with open(fn, "r", encoding="utf8") as file:
                data = json.load(file)
        except Exception:
            print(f"Error loading localization from {fn}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return f"var localization = {json.dumps(data)}\n"
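`localization_js` simply serializes whichever JSON file matches the selected localization into a `var localization = …` statement for the frontend. For instance:

```python
import json

data = {"Sampling method": "Méthode d'échantillonnage"}  # sample translation entry
print(f"var localization = {json.dumps(data)}\n")
# var localization = {"Sampling method": "M\u00e9thode d'\u00e9chantillonnage"}
```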
sd/stable-diffusion-webui/modules/lowvram.py
ADDED
@@ -0,0 +1,96 @@
import torch
from modules import devices

module_in_gpu = None
cpu = torch.device("cpu")


def send_everything_to_cpu():
    global module_in_gpu

    if module_in_gpu is not None:
        module_in_gpu.to(cpu)

    module_in_gpu = None


def setup_for_low_vram(sd_model, use_medvram):
    parents = {}

    def send_me_to_gpu(module, _):
        """send this module to GPU; send whatever tracked module was previously in GPU to CPU;
        we add this as a forward_pre_hook to a lot of modules, and this way all but one of them will
        be on CPU
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu == module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(devices.device)
        module_in_gpu = module

    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods

    first_stage_model = sd_model.first_stage_model
    first_stage_model_encode = sd_model.first_stage_model.encode
    first_stage_model_decode = sd_model.first_stage_model.decode

    def first_stage_model_encode_wrap(x):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_encode(x)

    def first_stage_model_decode_wrap(z):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_decode(z)

    # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in the model field
    if hasattr(sd_model.cond_stage_model, 'model'):
        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model

    # remove four big modules: cond, first_stage, depth (if applicable), and unet from the model, then
    # send the model to GPU. Then put the modules back. The modules will stay on CPU.
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
    sd_model.to(devices.device)
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored

    # register hooks for the first three models
    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
    if sd_model.depth_model:
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if hasattr(sd_model.cond_stage_model, 'model'):
        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
        del sd_model.cond_stage_model.transformer

    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
        diff_model = sd_model.model.diffusion_model

        # the remaining model is still too big for 4 GB, so we also do the same for its submodules
        # so that only one of them is in GPU at a time
        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
        sd_model.model.to(devices.device)
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

        # install hooks for bits of the third model
        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.input_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)
        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.output_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)
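The key mechanism here is `register_forward_pre_hook`: right before a module's `forward` runs, the hook moves it onto the compute device and evicts whichever module was resident before. A CPU-only sketch of the same one-module-resident scheme (the devices are stand-ins, since the point is the hook bookkeeping, not actual VRAM savings):

```python
import torch
import torch.nn as nn

module_in_gpu = None
compute_device = torch.device("cpu")  # stand-in for the real GPU device

def send_me_to_gpu(module, _inputs):
    # Evict the previously resident module, then admit this one.
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(torch.device("cpu"))
    module.to(compute_device)
    module_in_gpu = module

layers = [nn.Linear(8, 8) for _ in range(3)]
for layer in layers:
    layer.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 8)
for layer in layers:
    x = layer(x)  # each forward call swaps residency to this layer
print(module_in_gpu is layers[-1])  # True
```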
sd/stable-diffusion-webui/modules/mac_specific.py
ADDED
@@ -0,0 +1,53 @@
import torch
from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version


# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
def check_for_mps() -> bool:
    if not getattr(torch, 'has_mps', False):
        return False
    try:
        torch.zeros(1).to(torch.device("mps"))
        return True
    except Exception:
        return False
has_mps = check_for_mps()


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
    if input.device.type == 'mps':
        output_dtype = kwargs.get('dtype', input.dtype)
        if output_dtype == torch.int64:
            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
    return cumsum_func(input, *args, **kwargs)


if has_mps:
    # MPS fix for randn in torchsde
    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')

    if version.parse(torch.__version__) < version.parse("1.13"):
        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working

        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
            lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
            lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
    elif version.parse(torch.__version__) > version.parse("1.13.1"):
        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
        CondFunc('torch.cumsum', cumsum_fix_func, None)
        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
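`CondFunc` (defined in `sd_hijack_utils`, not shown in this section) replaces a dotted-path function with a wrapper that only reroutes calls when a predicate matches. A hand-rolled sketch of the same idea for `torch.cumsum`, rerouting just one dtype (purely illustrative; the real fixes above target MPS devices):

```python
import torch

_orig_cumsum = torch.cumsum

def patched_cumsum(input, *args, **kwargs):
    # Condition: only reroute int16 inputs; everything else passes through untouched.
    if input.dtype == torch.int16:
        return _orig_cumsum(input.to(torch.int32), *args, **kwargs).to(torch.int64)
    return _orig_cumsum(input, *args, **kwargs)

torch.cumsum = patched_cumsum
out = torch.cumsum(torch.tensor([1, 2, 3], dtype=torch.int16), dim=0)
print(out, out.dtype)  # tensor([1, 3, 6]) torch.int64
```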
sd/stable-diffusion-webui/modules/masking.py
ADDED
@@ -0,0 +1,99 @@
from PIL import Image, ImageFilter, ImageOps


def get_crop_region(mask, pad=0):
    """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
    For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""

    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )


def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
    """expands the crop region returned by get_crop_region() to match the aspect ratio the region will be processed at; returns the expanded region.
    For example, if the user drew a mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""

    x1, y1, x2, y2 = crop_region

    ratio_crop_region = (x2 - x1) / (y2 - y1)
    ratio_processing = processing_width / processing_height

    if ratio_crop_region > ratio_processing:
        desired_height = (x2 - x1) / ratio_processing
        desired_height_diff = int(desired_height - (y2-y1))
        y1 -= desired_height_diff//2
        y2 += desired_height_diff - desired_height_diff//2
        if y2 >= image_height:
            diff = y2 - image_height
            y2 -= diff
            y1 -= diff
        if y1 < 0:
            y2 -= y1
            y1 -= y1
        if y2 >= image_height:
            y2 = image_height
    else:
        desired_width = (y2 - y1) * ratio_processing
        desired_width_diff = int(desired_width - (x2-x1))
        x1 -= desired_width_diff//2
        x2 += desired_width_diff - desired_width_diff//2
        if x2 >= image_width:
            diff = x2 - image_width
            x2 -= diff
            x1 -= diff
        if x1 < 0:
            x2 -= x1
            x1 -= x1
        if x2 >= image_width:
            x2 = image_width

    return x1, y1, x2, y2


def fill(image, mask):
    """fills masked regions with colors from image using blur. Not extremely effective."""

    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")
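`get_crop_region` scans columns and rows inward until it hits a nonzero pixel, which amounts to taking the bounding box of the masked area. An equivalent numpy-only check on a toy mask (using `np.nonzero` instead of the scan loops):

```python
import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[0:256, 256:512] = 255  # user painted the top-right quarter

ys, xs = np.nonzero(mask)
x1, y1, x2, y2 = xs.min(), ys.min(), xs.max() + 1, ys.max() + 1
print(x1, y1, x2, y2)  # 256 0 512 256, matching the docstring's example
```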
sd/stable-diffusion-webui/modules/memmon.py
ADDED
@@ -0,0 +1,88 @@
import threading
import time
from collections import defaultdict

import torch


class MemUsageMonitor(threading.Thread):
    run_flag = None
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        try:
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def run(self):
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            torch.cuda.reset_peak_memory_stats()
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            while self.run_flag.is_set():
                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                self.data["min_free"] = min(self.data["min_free"], free)

                time.sleep(1 / self.opts.memmon_poll_rate)

    def dump_debug(self):
        print(self, 'recorded data:')
        for k, v in self.read().items():
            print(k, -(v // -(1024 ** 2)))  # negated floor division rounds up: bytes -> MiB, ceiling

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        self.run_flag.set()

    def read(self):
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            self.data["active"] = torch_stats["active.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        self.run_flag.clear()
        return self.read()
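The monitor thread is gated on a `threading.Event`: `monitor()` sets it to start a polling pass, and the poll loop runs until `stop()` clears it. A minimal sketch of that start/stop pattern without the CUDA calls:

```python
import threading
import time

class Poller(threading.Thread):
    def __init__(self):
        super().__init__(daemon=True)
        self.run_flag = threading.Event()
        self.samples = 0

    def run(self):
        while True:
            self.run_flag.wait()           # park until monitor() is called
            while self.run_flag.is_set():  # poll until stop() clears the flag
                self.samples += 1
                time.sleep(0.05)

poller = Poller()
poller.start()
poller.run_flag.set()    # the equivalent of monitor()
time.sleep(0.2)
poller.run_flag.clear()  # the equivalent of stop()
print(poller.samples > 0)  # True
```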
sd/stable-diffusion-webui/modules/modelloader.py
ADDED
@@ -0,0 +1,172 @@
import glob
import os
import shutil
import importlib
from urllib.parse import urlparse

from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path


def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
    """
    A one-and-done loader to try finding the desired models in specified directories.

    @param download_name: Specify to download from model_url immediately.
    @param model_url: If no other models are found, this will be downloaded on upscale.
    @param model_path: The location to store/find models in.
    @param command_path: A command-line argument to search for models in first.
    @param ext_filter: An optional list of filename extensions to filter by
    @param ext_blacklist: An optional list of filename extensions to exclude
    @return: A list of paths containing the desired model(s)
    """
    output = []

    if ext_filter is None:
        ext_filter = []

    try:
        places = []

        if command_path is not None and command_path != model_path:
            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
            if os.path.exists(pretrained_path):
                print(f"Appending path: {pretrained_path}")
                places.append(pretrained_path)
            elif os.path.exists(command_path):
                places.append(command_path)

        places.append(model_path)

        for place in places:
            if os.path.exists(place):
                for file in glob.iglob(place + '**/**', recursive=True):
                    full_path = file
                    if os.path.isdir(full_path):
                        continue
                    if os.path.islink(full_path) and not os.path.exists(full_path):
                        print(f"Skipping broken symlink: {full_path}")
                        continue
                    if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
                        continue
                    if len(ext_filter) != 0:
                        model_name, extension = os.path.splitext(file)
                        if extension not in ext_filter:
                            continue
                    if full_path not in output:
                        output.append(full_path)

        if model_url is not None and len(output) == 0:
            if download_name is not None:
                dl = load_file_from_url(model_url, model_path, True, download_name)
                output.append(dl)
            else:
                output.append(model_url)

    except Exception:
        pass

    return output


def friendly_name(file: str):
    if "http" in file:
        file = urlparse(file).path

    file = os.path.basename(file)
    model_name, extension = os.path.splitext(file)
    return model_name


def cleanup_models():
    # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
    # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
    # somehow auto-register and just do these things...
    root_path = script_path
    src_path = models_path
    dest_path = os.path.join(models_path, "Stable-diffusion")
    move_files(src_path, dest_path, ".ckpt")
    move_files(src_path, dest_path, ".safetensors")
    src_path = os.path.join(root_path, "ESRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    move_files(src_path, dest_path)
    src_path = os.path.join(models_path, "BSRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    move_files(src_path, dest_path, ".pth")
    src_path = os.path.join(root_path, "gfpgan")
    dest_path = os.path.join(models_path, "GFPGAN")
    move_files(src_path, dest_path)
    src_path = os.path.join(root_path, "SwinIR")
    dest_path = os.path.join(models_path, "SwinIR")
    move_files(src_path, dest_path)
    src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
    dest_path = os.path.join(models_path, "LDSR")
    move_files(src_path, dest_path)


def move_files(src_path: str, dest_path: str, ext_filter: str = None):
    try:
        if not os.path.exists(dest_path):
            os.makedirs(dest_path)
        if os.path.exists(src_path):
            for file in os.listdir(src_path):
                fullpath = os.path.join(src_path, file)
                if os.path.isfile(fullpath):
                    if ext_filter is not None:
                        if ext_filter not in file:
                            continue
                    print(f"Moving {file} from {src_path} to {dest_path}.")
                    try:
                        shutil.move(fullpath, dest_path)
                    except Exception:
                        pass
            if len(os.listdir(src_path)) == 0:
                print(f"Removing empty folder: {src_path}")
                shutil.rmtree(src_path, True)
    except Exception:
        pass


builtin_upscaler_classes = []
forbidden_upscaler_classes = set()


def list_builtin_upscalers():
    load_upscalers()

    builtin_upscaler_classes.clear()
    builtin_upscaler_classes.extend(Upscaler.__subclasses__())


def forbid_loaded_nonbuiltin_upscalers():
    for cls in Upscaler.__subclasses__():
        if cls not in builtin_upscaler_classes:
            forbidden_upscaler_classes.add(cls)


def load_upscalers():
    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
    # so we'll try to import any _model.py files before looking in __subclasses__
    modules_dir = os.path.join(shared.script_path, "modules")
    for file in os.listdir(modules_dir):
        if "_model.py" in file:
            model_name = file.replace("_model.py", "")
            full_model = f"modules.{model_name}_model"
            try:
                importlib.import_module(full_model)
            except Exception:
                pass

    datas = []
    commandline_options = vars(shared.cmd_opts)
    for cls in Upscaler.__subclasses__():
        if cls in forbidden_upscaler_classes:
            continue

        name = cls.__name__
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        scaler = cls(commandline_options.get(cmd_name, None))
        datas += scaler.scalers

    shared.sd_upscalers = datas
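`friendly_name` reduces either a URL or a local path to a bare model name: keep the path component of a URL, take the basename, strip the extension. The same two-step logic stands alone:

```python
import os
from urllib.parse import urlparse

def friendly_name(file: str) -> str:
    if "http" in file:
        file = urlparse(file).path  # keep only the path component of a URL
    return os.path.splitext(os.path.basename(file))[0]

print(friendly_name("https://example.com/models/RealESRGAN_x4.pth"))  # RealESRGAN_x4
print(friendly_name("/models/ESRGAN/4x-UltraSharp.pth"))              # 4x-UltraSharp
```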
sd/stable-diffusion-webui/modules/ngrok.py
ADDED
@@ -0,0 +1,26 @@
from pyngrok import ngrok, conf, exception

def connect(token, port, region):
    account = None
    if token is None:
        token = 'None'
    else:
        if ':' in token:
            # token = authtoken:username:password
            account = token.split(':')[1] + ':' + token.split(':')[-1]
            token = token.split(':')[0]

    config = conf.PyngrokConfig(
        auth_token=token, region=region
    )
    try:
        if account is None:
            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
        else:
            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
    except exception.PyngrokNgrokError:
        print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
              f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
    else:
        print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
              'You can use this link after the launch is complete.')
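The `authtoken:username:password` parsing above splits one combined credential into an ngrok auth token plus a basic-auth account string for the tunnel. In isolation (the token value is a made-up example):

```python
token = "tok_abc123:alice:s3cret"  # hypothetical combined credential

account = None
if ":" in token:
    parts = token.split(":")
    account = parts[1] + ":" + parts[-1]  # "alice:s3cret" for tunnel basic auth
    token = parts[0]                      # "tok_abc123" for ngrok itself

print(token, account)  # tok_abc123 alice:s3cret
```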
sd/stable-diffusion-webui/modules/paths.py
ADDED
@@ -0,0 +1,62 @@
import argparse
import os
import sys
import modules.safe

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")

sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
    if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
        sd_path = os.path.abspath(possible_sd_path)
        break

assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)

path_dirs = [
    (sd_path, 'ldm', 'Stable Diffusion', []),
    (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
    (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]

paths = {}

for d, must_exist, what, options in path_dirs:
    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
    if not os.path.exists(must_exist_path):
        print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
    else:
        d = os.path.abspath(d)
        if "atstart" in options:
            sys.path.insert(0, d)
        else:
            sys.path.append(d)
        paths[what] = d


class Prioritize:
    def __init__(self, name):
        self.name = name
        self.path = None

    def __enter__(self):
        self.path = sys.path.copy()
        sys.path = [paths[self.name]] + sys.path

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.path = self.path
        self.path = None
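`Prioritize` temporarily puts one repository's directory at the front of `sys.path` so its modules shadow same-named ones elsewhere, then restores the old path on exit. A generic standalone version of the same context manager (the directory is a placeholder):

```python
import sys

class PrioritizePath:
    """Temporarily prepend a directory to sys.path."""
    def __init__(self, directory: str):
        self.directory = directory
        self.saved = None

    def __enter__(self):
        self.saved = sys.path.copy()
        sys.path = [self.directory] + sys.path

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.path = self.saved  # restore the original search order

with PrioritizePath("/tmp/some_repo"):
    print(sys.path[0])  # /tmp/some_repo shadows everything else
print(sys.path[0] != "/tmp/some_repo")  # True: original order restored
```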
sd/stable-diffusion-webui/modules/postprocessing.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os

from PIL import Image

from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
from modules.shared import opts


def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
    devices.torch_gc()

    shared.state.begin()
    shared.state.job = 'extras'

    image_data = []
    image_names = []
    outputs = []

    if extras_mode == 1:
        for img in image_folder:
            image = Image.open(img)
            image_data.append(image)
            image_names.append(os.path.splitext(img.orig_name)[0])
    elif extras_mode == 2:
        assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
        assert input_dir, 'input directory not selected'

        image_list = shared.listfiles(input_dir)
        for filename in image_list:
            try:
                image = Image.open(filename)
            except Exception:
                continue
            image_data.append(image)
            image_names.append(filename)
    else:
        assert image, 'image not selected'

        image_data.append(image)
        image_names.append(None)

    if extras_mode == 2 and output_dir != '':
        outpath = output_dir
    else:
        outpath = opts.outdir_samples or opts.outdir_extras_samples

    infotext = ''

    for image, name in zip(image_data, image_names):
        shared.state.textinfo = name

        existing_pnginfo = image.info or {}

        pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))

        scripts.scripts_postproc.run(pp, args)

        if opts.use_original_name_batch and name is not None:
            basename = os.path.splitext(os.path.basename(name))[0]
        else:
            basename = ''

        infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])

        if opts.enable_pnginfo:
            pp.image.info = existing_pnginfo
            pp.image.info["postprocessing"] = infotext

        if save_output:
            images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)

        if extras_mode != 2 or show_extras_results:
            outputs.append(pp.image)

    devices.torch_gc()

    return outputs, ui_common.plaintext_to_html(infotext), ''
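
run_postprocessing serializes each script's pp.info entries into the same "key: value" infotext format used elsewhere in the webui: None values are dropped, and each remaining value is quoted when needed. A small illustrative sketch of that join (toy dict; keys and values are hypothetical):

    from modules import generation_parameters_copypaste

    info = {"Postprocess upscaler": "Lanczos", "Postprocess upscale by": 2, "Skipped": None}
    text = ", ".join(
        k if k == v else f"{k}: {generation_parameters_copypaste.quote(v)}"
        for k, v in info.items()
        if v is not None
    )
    # -> 'Postprocess upscaler: Lanczos, Postprocess upscale by: 2'
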
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
    """old handler for API"""

    args = scripts.scripts_postproc.create_args_for_run({
        "Upscale": {
            "upscale_mode": resize_mode,
            "upscale_by": upscaling_resize,
            "upscale_to_width": upscaling_resize_w,
            "upscale_to_height": upscaling_resize_h,
            "upscale_crop": upscaling_crop,
            "upscaler_1_name": extras_upscaler_1,
            "upscaler_2_name": extras_upscaler_2,
            "upscaler_2_visibility": extras_upscaler_2_visibility,
        },
        "GFPGAN": {
            "gfpgan_visibility": gfpgan_visibility,
        },
        "CodeFormer": {
            "codeformer_visibility": codeformer_visibility,
            "codeformer_weight": codeformer_weight,
        },
    })

    return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
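
A hedged example of driving the legacy run_extras entry point directly, assuming a fully initialized webui (models loaded, scripts_postproc populated); file names and argument values are illustrative only:

    from modules.postprocessing import run_extras
    from PIL import Image

    img = Image.open("input.png")  # hypothetical input file
    outputs, info_html, _ = run_extras(
        extras_mode=0, resize_mode=0, image=img, image_folder=None,
        input_dir="", output_dir="", show_extras_results=True,
        gfpgan_visibility=0.0, codeformer_visibility=0.0, codeformer_weight=0.0,
        upscaling_resize=2, upscaling_resize_w=512, upscaling_resize_h=512,
        upscaling_crop=True, extras_upscaler_1="Lanczos", extras_upscaler_2="None",
        extras_upscaler_2_visibility=0.0, upscale_first=False, save_output=False,
    )
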
sd/stable-diffusion-webui/modules/processing.py
ADDED
@@ -0,0 +1,1056 @@
import json
import math
import os
import sys
import warnings

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any, Dict, List, Optional

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion

from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8


def setup_color_correction(image):
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, original_image):
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(original_image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image
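
setup_color_correction stores the LAB representation of a reference image; apply_color_correction later histogram-matches a processed image back to that reference, then restores the processed image's luminosity via a blend. A standalone sketch under the same dependencies (cv2, skimage, blendmodes), with hypothetical file names; both images are assumed to have the same dimensions:

    from PIL import Image
    from modules.processing import setup_color_correction, apply_color_correction

    reference = Image.open("before.png").convert("RGB")   # hypothetical
    processed = Image.open("after.png").convert("RGB")    # hypothetical, same size
    target = setup_color_correction(reference)
    corrected = apply_color_correction(target, processed)  # PIL image matched to the reference palette
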
def apply_overlay(image, paste_loc, index, overlays):
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = images.resize_image(1, image, w, h)
        base_image.paste(image, (x, y))
        image = base_image

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image


def txt2img_image_conditioning(sd_model, x, width, height):
    if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
        # Dummy zero conditioning if we're not using an inpainting model.
        # Still takes up a bit of memory, but no encoder call.
        # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)

    # The "masked-image" in this case will just be all zeros since the entire image is masked.
    image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
    image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))

    # Add the fake full 1s mask to the first dimension.
    image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
    image_conditioning = image_conditioning.to(x.dtype)

    return image_conditioning
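
For non-inpainting models, txt2img_image_conditioning returns a dummy five-channel, spatially trivial tensor; for hybrid/concat (inpainting) models it builds a real mask-plus-latent tensor. A tiny sketch of the dummy branch's shape arithmetic:

    import torch

    x = torch.zeros(2, 4, 64, 64)  # batch of 2 latents (opt_C=4, 512/8=64)
    dummy = x.new_zeros(x.shape[0], 5, 1, 1)
    assert dummy.shape == (2, 5, 1, 1)  # one mask channel + four latent channels, 1x1 spatially
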
class StableDiffusionProcessing:
    """
    The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
        self.override_settings_restore_afterwards = override_settings_restore_afterwards
        self.is_using_inpainting_conditioning = False
        self.disable_extra_networks = False

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = script_args
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None
        self.iteration = 0

    @property
    def sd_model(self):
        return shared.sd_model

    def txt2img_image_conditioning(self, x, width=None, height=None):
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        # Use the AddMiDaS helper to format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        conditioning_image = self.sd_model.encode_first_stage(source_image).mode()

        return conditioning_image

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        raise NotImplementedError()

    def close(self):
        self.sampler = None
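
inpainting_image_conditioning discretizes a PIL mask to {0, 1} before interpolation, since the inpainting model expects a hard mask. A sketch of that normalization on a toy array:

    import numpy as np
    import torch

    mask = np.array([[0, 96, 200], [255, 30, 128]], dtype=np.float32) / 255.0
    conditioning_mask = torch.round(torch.from_numpy(mask[None, None]))
    # values below 0.5 become 0.0, the rest 1.0 (128/255 ≈ 0.502 rounds to 1.0)
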
class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = comments
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
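
Processed.js() returns the generation metadata as a JSON string, which the UI hands to the browser; it round-trips cleanly through json.loads. A sketch, assuming `res` is a Processed returned by process_images:

    import json

    # res = process_images(p)  # hypothetical earlier call
    meta = json.loads(res.js())
    print(meta["seed"], meta["sampler_name"], meta["all_prompts"][0])
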
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        return low * val + high * (1 - val)

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
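
slerp interpolates between two noise tensors along the unit hypersphere (row-wise over dim 1). Note that the near-parallel fallback `low * val + high * (1 - val)` runs in the opposite direction from the spherical branch (it returns high at val=0, where the spherical branch returns low); this is left untouched here, presumably because changing it would alter existing seeds. A standalone sketch using the slerp defined above, with orthogonal vectors so the spherical branch runs:

    import torch

    low = torch.tensor([[1.0, 0.0]])
    high = torch.tensor([[0.0, 1.0]])   # dot == 0, so the spherical branch is taken
    mid = slerp(0.5, low, high)
    # halfway between orthogonal unit vectors: sin(pi/4)/sin(pi/2) * (low + high)
    assert torch.allclose(mid, torch.tensor([[0.5**0.5, 0.5**0.5]]), atol=1e-6)
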
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if eta_noise_seed_delta > 0:
                torch.manual_seed(seed + eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


def decode_first_stage(model, x):
    with devices.autocast(disable=x.dtype == devices.dtype_vae):
        x = model.decode_first_stage(x)

    return x


def get_fixed_seed(seed):
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))

    return seed


def fix_seed(p):
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)
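
get_fixed_seed resolves the "random seed" sentinels (None, '', -1) to a concrete integer and passes anything else through unchanged, so fix_seed makes p.seed and p.subseed reproducible before generation starts. A sketch using the function defined above:

    assert get_fixed_seed(1234) == 1234      # explicit seeds pass through
    s = get_fixed_seed(-1)                   # -1 (or '' or None) means "pick one at random"
    assert 0 <= s < 4294967294
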
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
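
create_infotext renders the parameters into the text block stored in PNG metadata: the prompt, an optional "Negative prompt:" line, then comma-separated key: value pairs with None entries omitted. An illustrative result (prompt and values are hypothetical):

    a photo of a cat
    Negative prompt: blurry
    Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 1234, Size: 512x512
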
def process_images(p: StableDiffusionProcessing) -> Processed:
    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        for k, v in p.override_settings.items():
            setattr(opts, k, v)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        res = process_images_inner(p)

    finally:
        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)
                if k == 'sd_model_checkpoint':
                    sd_models.reload_model_weights()

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res
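
process_images applies per-job option overrides, reloading model or VAE weights when those specific keys change, and (by default) restores everything afterwards. A sketch of populating the overrides on a processing object `p`; the keys are real option names from this codebase, the values are hypothetical, and note that the constructor (unlike direct assignment) filters out restricted options:

    p.override_settings = {
        "CLIP_stop_at_last_layers": 2,       # plain opts change, no reload
        "sd_vae": "vae-ft-mse.safetensors",  # hypothetical file; triggers reload_vae_weights()
    }
    p.override_settings_restore_afterwards = True
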
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    comments = {}

    if type(p.prompt) == list:
        p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
    else:
        p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

    if type(p.negative_prompt) == list:
        p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
    else:
        p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]

    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    cached_uc = [None, None]
    cached_c = [None, None]

    def get_conds_with_caching(function, required_prompts, steps, cache):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.
        """

        if cache[0] is not None and (required_prompts, steps) == cache[0]:
            return cache[1]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)

        cache[0] = (required_prompts, steps)
        return cache[1]

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if len(prompts) == 0:
                break

            prompts, extra_network_data = extra_networks.parse_prompts(prompts)

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [], p.seed, "")
                    file.write(processed.infotext(p, 0))

            uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
            c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
            for x in x_samples_ddim:
                devices.test_for_nans(x, "vae")

            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks:
        extra_networks.deactivate(p, extra_network_data)

    devices.torch_gc()

    res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
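
get_conds_with_caching memoizes the most recent (prompts, steps) pair per cache slot, so consecutive batches with identical prompts skip re-encoding the conditioning. The same pattern in isolation, stripped of the model specifics:

    def cached_call(function, args, cache):
        # cache == [last_args_or_None, last_result]
        if cache[0] is not None and args == cache[0]:
            return cache[1]
        cache[1] = function(*args)
        cache[0] = args
        return cache[1]

    cache = [None, None]
    assert cached_call(lambda a, b: a + b, (1, 2), cache) == 3
    assert cached_call(lambda a, b: a + b, (1, 2), cache) == 3  # served from cache
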
def old_hires_fix_first_pass_dimensions(width, height):
    """old algorithm for auto-calculating first pass size"""

    desired_pixel_count = 512 * 512
    actual_pixel_count = width * height
    scale = math.sqrt(desired_pixel_count / actual_pixel_count)
    width = math.ceil(scale * width / 64) * 64
    height = math.ceil(scale * height / 64) * 64

    return width, height
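
old_hires_fix_first_pass_dimensions scales the requested resolution down toward a 512x512 pixel budget, rounding each side up to a multiple of 64. A worked example using the function defined above:

    # width=1024, height=768
    # scale  = sqrt(512*512 / (1024*768)) ≈ 0.5774
    # width  = ceil(0.5774 * 1024 / 64) * 64 = ceil(9.24) * 64 = 640
    # height = ceil(0.5774 *  768 / 64) * 64 = ceil(6.93) * 64 = 448
    assert old_hires_fix_first_pass_dimensions(1024, 768) == (640, 448)
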
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler
        self.hr_second_pass_steps = hr_second_pass_steps
        self.hr_resize_x = hr_resize_x
        self.hr_resize_y = hr_resize_y
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_y = hr_resize_y

        if firstphase_width != 0 or firstphase_height != 0:
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height
            self.width = firstphase_width
            self.height = firstphase_height

        self.truncate_x = 0
        self.truncate_y = 0
        self.applied_old_hires_behavior_to = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                self.hr_resize_x = self.width
                self.hr_resize_y = self.height
                self.hr_upscale_to_x = self.width
                self.hr_upscale_to_y = self.height

                self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
                self.applied_old_hires_behavior_to = (self.width, self.height)

            if self.hr_resize_x == 0 and self.hr_resize_y == 0:
                self.extra_generation_params["Hires upscale"] = self.hr_scale
                self.hr_upscale_to_x = int(self.width * self.hr_scale)
                self.hr_upscale_to_y = int(self.height * self.hr_scale)
            else:
                self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

                if self.hr_resize_y == 0:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                elif self.hr_resize_x == 0:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y
                else:
                    target_w = self.hr_resize_x
                    target_h = self.hr_resize_y
                    src_ratio = self.width / self.height
                    dst_ratio = self.hr_resize_x / self.hr_resize_y

                    if src_ratio < dst_ratio:
                        self.hr_upscale_to_x = self.hr_resize_x
                        self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                    else:
                        self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                        self.hr_upscale_to_y = self.hr_resize_y

                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

            # special case: the user has chosen to do nothing
            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
                self.enable_hr = False
                self.denoising_strength = None
                self.extra_generation_params.pop("Hires upscale", None)
                self.extra_generation_params.pop("Hires resize", None)
                return

            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
        if self.enable_hr and latent_scale_mode is None:
            assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))

        if not self.enable_hr:
            return samples

        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")

        if latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM'  # PLMS does not support img2img so we just silently switch to DDIM
        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        return samples
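
In init(), the second-pass target is either width*hr_scale (when no explicit resize is given) or derived from hr_resize_x/hr_resize_y, filling in a missing side from the first-pass aspect ratio. Worked examples of both paths (values hypothetical):

    # hr_scale path: first pass 512x512, hr_scale=2 -> upscale to 1024x1024
    # hr_resize path: first pass 512x768, hr_resize_x=768, hr_resize_y=0
    #   -> hr_upscale_to_x = 768
    #   -> hr_upscale_to_y = hr_resize_x * height // width = 768 * 768 // 512 = 1152
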
907 |
+
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
908 |
+
sampler = None
|
909 |
+
|
910 |
+
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
911 |
+
super().__init__(**kwargs)
|
912 |
+
|
913 |
+
self.init_images = init_images
|
914 |
+
self.resize_mode: int = resize_mode
|
915 |
+
self.denoising_strength: float = denoising_strength
|
916 |
+
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
917 |
+
self.init_latent = None
|
918 |
+
self.image_mask = mask
|
919 |
+
self.latent_mask = None
|
920 |
+
self.mask_for_overlay = None
|
921 |
+
self.mask_blur = mask_blur
|
922 |
+
self.inpainting_fill = inpainting_fill
|
923 |
+
self.inpaint_full_res = inpaint_full_res
|
924 |
+
self.inpaint_full_res_padding = inpaint_full_res_padding
|
925 |
+
self.inpainting_mask_invert = inpainting_mask_invert
|
926 |
+
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
|
927 |
+
self.mask = None
|
928 |
+
self.nmask = None
|
929 |
+
self.image_conditioning = None
|
930 |
+
|
931 |
+
def init(self, all_prompts, all_seeds, all_subseeds):
|
932 |
+
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
933 |
+
crop_region = None
|
934 |
+
|
935 |
+
image_mask = self.image_mask
|
936 |
+
|
937 |
+
if image_mask is not None:
|
938 |
+
image_mask = image_mask.convert('L')
|
939 |
+
|
940 |
+
if self.inpainting_mask_invert:
|
941 |
+
image_mask = ImageOps.invert(image_mask)
|
942 |
+
|
943 |
+
if self.mask_blur > 0:
|
944 |
+
image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
945 |
+
|
946 |
+
if self.inpaint_full_res:
|
947 |
+
self.mask_for_overlay = image_mask
|
948 |
+
mask = image_mask.convert('L')
|
949 |
+
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
950 |
+
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
951 |
+
x1, y1, x2, y2 = crop_region
|
952 |
+
|
953 |
+
mask = mask.crop(crop_region)
|
954 |
+
image_mask = images.resize_image(2, mask, self.width, self.height)
|
955 |
+
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
956 |
+
else:
|
957 |
+
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
958 |
+
np_mask = np.array(image_mask)
|
959 |
+
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
|
960 |
+
self.mask_for_overlay = Image.fromarray(np_mask)
|
961 |
+
|
962 |
+
self.overlay_images = []
|
963 |
+
|
964 |
+
        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = images.flatten(img, opts.img2img_background_color)

            if crop_region is None and self.resize_mode != 3:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.resize_mode == 3:
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples
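Note on the masked blend at the end of sample(): `self.nmask` is 1.0 over inpainted latent pixels and `self.mask` is its complement, so `samples * nmask + init_latent * mask` restores untouched regions exactly. A minimal sketch (not part of the diff; shapes and values are made up for illustration):

```python
# Illustrative only: nmask marks inpainted pixels (1.0), mask = 1 - nmask marks
# preserved pixels, so the sum returns the original latent outside the masked area.
import torch

init_latent = torch.randn(1, 4, 64, 64)   # hypothetical encoded init image
samples = torch.randn(1, 4, 64, 64)       # hypothetical denoised output
nmask = torch.zeros(1, 4, 64, 64)
nmask[..., 16:48, 16:48] = 1.0            # inpaint only the central region
mask = 1.0 - nmask

blended = samples * nmask + init_latent * mask
assert torch.equal(blended[..., 0, 0], init_latent[..., 0, 0])  # border preserved
```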
sd/stable-diffusion-webui/modules/progress.py
ADDED
@@ -0,0 +1,99 @@
import base64
import io
import time

import gradio as gr
from pydantic import BaseModel, Field

from modules.shared import opts

import modules.shared as shared


current_task = None
pending_tasks = {}
finished_tasks = []


def start_task(id_task):
    global current_task

    current_task = id_task
    pending_tasks.pop(id_task, None)


def finish_task(id_task):
    global current_task

    if current_task == id_task:
        current_task = None

    finished_tasks.append(id_task)
    if len(finished_tasks) > 16:
        finished_tasks.pop(0)


def add_task_to_queue(id_job):
    pending_tasks[id_job] = time.time()


class ProgressRequest(BaseModel):
    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received preview image")


class ProgressResponse(BaseModel):
    active: bool = Field(title="Whether the task is being worked on right now")
    queued: bool = Field(title="Whether the task is in queue")
    completed: bool = Field(title="Whether the task has already finished")
    progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
    eta: float = Field(default=None, title="ETA in secs")
    live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
    id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")


def setup_progress_api(app):
    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)


def progressapi(req: ProgressRequest):
    active = req.id_task == current_task
    queued = req.id_task in pending_tasks
    completed = req.id_task in finished_tasks

    if not active:
        return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")

    progress = 0

    job_count, job_no = shared.state.job_count, shared.state.job_no
    sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step

    if job_count > 0:
        progress += job_no / job_count
    if sampling_steps > 0 and job_count > 0:
        progress += 1 / job_count * sampling_step / sampling_steps

    progress = min(progress, 1)

    elapsed_since_start = time.time() - shared.state.time_start
    predicted_duration = elapsed_since_start / progress if progress > 0 else None
    eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None

    id_live_preview = req.id_live_preview
    shared.state.set_current_image()
    if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
        image = shared.state.current_image
        if image is not None:
            buffered = io.BytesIO()
            image.save(buffered, format="png")
            live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
            id_live_preview = shared.state.id_live_preview
        else:
            live_preview = None
    else:
        live_preview = None

    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
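A sketch (not part of the diff) of how a client might drive the /internal/progress endpoint above; it assumes a webui instance at http://127.0.0.1:7860, the `requests` package, and an `id_task` matching one the UI submitted:

```python
# Poll /internal/progress until the task completes; id_live_preview is echoed
# back so the server skips re-sending a preview image we already have.
import time
import requests

def poll_progress(id_task, id_live_preview=-1):
    while True:
        resp = requests.post(
            "http://127.0.0.1:7860/internal/progress",
            json={"id_task": id_task, "id_live_preview": id_live_preview},
        ).json()
        if resp["completed"]:
            return
        if resp.get("progress") is not None:
            print(f"{resp['progress'] * 100:.0f}% done, eta {resp.get('eta')}")
        id_live_preview = resp.get("id_live_preview", id_live_preview)
        time.sleep(1)
```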
sd/stable-diffusion-webui/modules/prompt_parser.py
ADDED
@@ -0,0 +1,373 @@
import re
from collections import namedtuple
from typing import List
import lark

# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']

schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")

def get_learned_conditioning_prompt_schedules(prompts, steps):
    """
    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b d'], [10, 'a c d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c")  # not handling this right now
    [[5, 'a c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    >>> g("[a|(b:1.1)]")
    [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
    """

    def collect_steps(steps, tree):
        l = [steps]
        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                tree.children[-1] = float(tree.children[-1])
                if tree.children[-1] < 1:
                    tree.children[-1] *= steps
                tree.children[-1] = min(steps, int(tree.children[-1]))
                l.append(tree.children[-1])
            def alternate(self, tree):
                l.extend(range(1, steps+1))
        CollectSteps().visit(tree)
        return sorted(set(l))

    def at_step(step, tree):
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                before, after, _, when = args
                yield before or () if step <= when else after
            def alternate(self, args):
                yield next(args[(step - 1)%len(args)])
            def start(self, args):
                def flatten(x):
                    if type(x) == str:
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)
                return ''.join(flatten(args))
            def plain(self, args):
                yield args[0].value
            def __default__(self, data, children, meta):
                for child in children:
                    yield child
        return AtStep().transform(tree)

    def get_schedule(prompt):
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError as e:
            if 0:
                import traceback
                traceback.print_exc()
            return [[steps, prompt]]
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
    return [promptdict[prompt] for prompt in prompts]


ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])


def get_learned_conditioning(model, prompts, steps):
    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
    and the sampling step at which this condition is to be replaced by the next one.

    Input:
    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)

    Output:
    [
        [
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0523,  ..., -0.4901, -0.3066,  0.0674], ..., [ 0.3317, -0.5102, -0.4066,  ...,  0.4119, -0.7647, -1.0160]], device='cuda:0'))
        ],
        [
            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.0192,  0.3867, -0.4644,  ...,  0.1135, -0.3696, -0.4625]], device='cuda:0')),
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.7352, -0.4356, -0.7888,  ...,  0.6994, -0.4312, -1.2593]], device='cuda:0'))
        ]
    ]
    """
    res = []

    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
    cache = {}

    for prompt, prompt_schedule in zip(prompts, prompt_schedules):

        cached = cache.get(prompt, None)
        if cached is not None:
            res.append(cached)
            continue

        texts = [x[1] for x in prompt_schedule]
        conds = model.get_learned_conditioning(texts)

        cond_schedule = []
        for i, (end_at_step, text) in enumerate(prompt_schedule):
            cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))

        cache[prompt] = cond_schedule
        res.append(cond_schedule)

    return res


re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")

def get_multicond_prompt_list(prompts):
    res_indexes = []

    prompt_flat_list = []
    prompt_indexes = {}

    for prompt in prompts:
        subprompts = re_AND.split(prompt)

        indexes = []
        for subprompt in subprompts:
            match = re_weight.search(subprompt)

            text, weight = match.groups() if match is not None else (subprompt, 1.0)

            weight = float(weight) if weight is not None else 1.0

            index = prompt_indexes.get(text, None)
            if index is None:
                index = len(prompt_flat_list)
                prompt_flat_list.append(text)
                prompt_indexes[text] = index

            indexes.append((index, weight))

        res_indexes.append(indexes)

    return res_indexes, prompt_flat_list, prompt_indexes


class ComposableScheduledPromptConditioning:
    def __init__(self, schedules, weight=1.0):
        self.schedules: List[ScheduledPromptConditioning] = schedules
        self.weight: float = weight


class MulticondLearnedConditioning:
    def __init__(self, shape, batch):
        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch

def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
    """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
    For each prompt, the list is obtained by splitting the prompt using the AND separator.

    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
    """

    res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)

    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)

    res = []
    for indexes in res_indexes:
        res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])

    return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)


def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
    param = c[0][0].cond
    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
    for i, cond_schedule in enumerate(c):
        target_index = 0
        for current, (end_at, cond) in enumerate(cond_schedule):
            if current_step <= end_at:
                target_index = current
                break
        res[i] = cond_schedule[target_index].cond

    return res


def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
    param = c.batch[0][0].schedules[0].cond

    tensors = []
    conds_list = []

    for batch_no, composable_prompts in enumerate(c.batch):
        conds_for_batch = []

        for cond_index, composable_prompt in enumerate(composable_prompts):
            target_index = 0
            for current, (end_at, cond) in enumerate(composable_prompt.schedules):
                if current_step <= end_at:
                    target_index = current
                    break

            conds_for_batch.append((len(tensors), composable_prompt.weight))
            tensors.append(composable_prompt.schedules[target_index].cond)

        conds_list.append(conds_for_batch)

    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
    # and won't be able to torch.stack them. So this fixes that.
    token_count = max([x.shape[0] for x in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] != token_count:
            last_vector = tensors[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])

    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)


re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)

def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ')' and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re_break, text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res

if __name__ == "__main__":
    import doctest
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
else:
    import torch  # doctest faster
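One case the doctests above do not cover is the BREAK keyword, which parse_prompt_attention emits as a chunk with the sentinel weight -1. A small sketch (not part of the diff):

```python
# BREAK splits the surrounding text and inserts a ['BREAK', -1] marker;
# whitespace around the keyword is consumed by re_break.
from modules.prompt_parser import parse_prompt_attention  # import path assumed from this repo layout

print(parse_prompt_attention("a (red:1.2) car BREAK a field"))
# [['a ', 1.0], ['red', 1.2], [' car', 1.0], ['BREAK', -1], ['a field', 1.0]]
```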
sd/stable-diffusion-webui/modules/realesrgan_model.py
ADDED
@@ -0,0 +1,129 @@
import os
import sys
import traceback

import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts


class UpscalerRealESRGAN(Upscaler):
    def __init__(self, path):
        self.name = "RealESRGAN"
        self.user_path = path
        super().__init__()
        try:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
            from realesrgan.archs.srvgg_arch import SRVGGNetCompact
            self.enable = True
            self.scalers = []
            scalers = self.load_models(path)
            for scaler in scalers:
                if scaler.name in opts.realesrgan_enabled_models:
                    self.scalers.append(scaler)

        except Exception:
            print("Error importing Real-ESRGAN:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            self.enable = False
            self.scalers = []

    def do_upscale(self, img, path):
        if not self.enable:
            return img

        info = self.load_model(path)
        if not os.path.exists(info.local_data_path):
            print("Unable to load RealESRGAN model: %s" % info.name)
            return img

        upsampler = RealESRGANer(
            scale=info.scale,
            model_path=info.local_data_path,
            model=info.model(),
            half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
            tile=opts.ESRGAN_tile,
            tile_pad=opts.ESRGAN_tile_overlap,
        )

        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]

        image = Image.fromarray(upsampled)
        return image

    def load_model(self, path):
        try:
            info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)

            if info is None:
                print(f"Unable to find model info: {path}")
                return None

            info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
            return info
        except Exception as e:
            print(f"Error loading Real-ESRGAN model: {e}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return None

    def load_models(self, _):
        return get_realesrgan_models(self)


def get_realesrgan_models(scaler):
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        models = [
            UpscalerData(
                name="R-ESRGAN General 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN General WDN 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN AnimeVideo",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN 4x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 4x+ Anime6B",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 2x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                scale=2,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            ),
        ]
        return models
    except Exception as e:
        print("Error making Real-ESRGAN models list:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
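A hypothetical sketch (not part of the diff) of driving this upscaler directly; in the webui these objects are normally constructed through modules.modelloader, and the scaler list depends on shared opts being initialized:

```python
# Upscale a PIL image with the first enabled Real-ESRGAN scaler; passing the
# scaler's data_path (a URL) makes load_model() download the weights on first use.
from PIL import Image
from modules.realesrgan_model import UpscalerRealESRGAN

upscaler = UpscalerRealESRGAN(path=None)          # builds the model list above
scaler = upscaler.scalers[0]                      # e.g. "R-ESRGAN General 4xV3"
img = Image.open("input.png").convert("RGB")      # input.png is a made-up filename
out = upscaler.do_upscale(img, scaler.data_path)
out.save("output_4x.png")
```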
sd/stable-diffusion-webui/modules/safe.py
ADDED
@@ -0,0 +1,192 @@
# this code is adapted from the script contributed by anon from /h/

import io
import pickle
import collections
import sys
import traceback

import torch
import numpy
import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage


def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'
        return TypedStorage()

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")

def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for i in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
        print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
        return None

    except Exception:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
        print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
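A quick sketch (not part of the diff) of what the zip-entry whitelist above accepts: only '<dir>/version', '<dir>/data.pkl', and '<dir>/data/<number>' entries pass check_zip_filenames; anything else raises.

```python
# Demonstrates the allowed_zip_names_re whitelist on made-up archive entry names.
import re

allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")

for name in ["archive/data.pkl", "archive/data/0", "archive/version",
             "archive/evil.py", "archive/data/extra/0"]:
    print(name, "->", bool(allowed_zip_names_re.match(name)))
# archive/data.pkl -> True, archive/data/0 -> True, archive/version -> True,
# archive/evil.py -> False, archive/data/extra/0 -> False
```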
sd/stable-diffusion-webui/modules/script_callbacks.py
ADDED
@@ -0,0 +1,359 @@
import sys
import traceback
from collections import namedtuple
import inspect
from typing import Optional, Dict, Any

from fastapi import FastAPI
from gradio import Blocks


def report_exception(c, job):
    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)


class ImageSaveParams:
    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        """the PIL image itself"""

        self.p = p
        """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""

        self.filename = filename
        """name of file that the image would be saved to"""

        self.pnginfo = pnginfo
        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""


class CFGDenoiserParams:
    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.image_cond = image_cond
        """Conditioning image"""

        self.sigma = sigma
        """Current sigma noise step value"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""


class CFGDenoisedParams:
    def __init__(self, x, sampling_step, total_sampling_steps):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""


class UiTrainTabParams:
    def __init__(self, txt2img_preview_params):
        self.txt2img_preview_params = txt2img_preview_params


class ImageGridLoopParams:
    def __init__(self, imgs, cols, rows):
        self.imgs = imgs
        self.cols = cols
        self.rows = rows


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
    callbacks_app_started=[],
    callbacks_model_loaded=[],
    callbacks_ui_tabs=[],
    callbacks_ui_train_tabs=[],
    callbacks_ui_settings=[],
    callbacks_before_image_saved=[],
    callbacks_image_saved=[],
    callbacks_cfg_denoiser=[],
    callbacks_cfg_denoised=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
    callbacks_image_grid=[],
    callbacks_infotext_pasted=[],
    callbacks_script_unloaded=[],
    callbacks_before_ui=[],
)


def clear_callbacks():
    for callback_list in callback_map.values():
        callback_list.clear()


def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    for c in callback_map['callbacks_app_started']:
        try:
            c.callback(demo, app)
        except Exception:
            report_exception(c, 'app_started_callback')


def model_loaded_callback(sd_model):
    for c in callback_map['callbacks_model_loaded']:
        try:
            c.callback(sd_model)
        except Exception:
            report_exception(c, 'model_loaded_callback')


def ui_tabs_callback():
    res = []

    for c in callback_map['callbacks_ui_tabs']:
        try:
            res += c.callback() or []
        except Exception:
            report_exception(c, 'ui_tabs_callback')

    return res


def ui_train_tabs_callback(params: UiTrainTabParams):
    for c in callback_map['callbacks_ui_train_tabs']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'callbacks_ui_train_tabs')


def ui_settings_callback():
    for c in callback_map['callbacks_ui_settings']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'ui_settings_callback')


def before_image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_before_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')


def image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_saved_callback')


def cfg_denoiser_callback(params: CFGDenoiserParams):
    for c in callback_map['callbacks_cfg_denoiser']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoiser_callback')


def cfg_denoised_callback(params: CFGDenoisedParams):
    for c in callback_map['callbacks_cfg_denoised']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoised_callback')


def before_component_callback(component, **kwargs):
    for c in callback_map['callbacks_before_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'before_component_callback')


def after_component_callback(component, **kwargs):
    for c in callback_map['callbacks_after_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'after_component_callback')


def image_grid_callback(params: ImageGridLoopParams):
    for c in callback_map['callbacks_image_grid']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_grid')


def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
    for c in callback_map['callbacks_infotext_pasted']:
        try:
            c.callback(infotext, params)
        except Exception:
            report_exception(c, 'infotext_pasted')


def script_unloaded_callback():
    for c in reversed(callback_map['callbacks_script_unloaded']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'script_unloaded')


def before_ui_callback():
    for c in reversed(callback_map['callbacks_before_ui']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'before_ui')


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'

    callbacks.append(ScriptCallback(filename, fun))


def remove_current_script_callbacks():
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
    if filename == 'unknown file':
        return
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
            callback_list.remove(callback_to_remove)


def remove_callbacks_for_function(callback_func):
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
            callback_list.remove(callback_to_remove)


def on_app_started(callback):
    """register a function to be called when the webui starts; the gradio `Blocks` component and
    fastapi `FastAPI` object are passed as the arguments"""
    add_callback(callback_map['callbacks_app_started'], callback)


def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument; this function is also called when the script is reloaded. """
    add_callback(callback_map['callbacks_model_loaded'], callback)


def on_ui_tabs(callback):
    """register a function to be called when the UI is creating new tabs.
    The function must either return None, which means no new tabs are to be added, or a list, where
    each element is a tuple:
        (gradio_component, title, elem_id)

    gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
    title is tab text displayed to user in the UI
    elem_id is HTML id for the tab
    """
    add_callback(callback_map['callbacks_ui_tabs'], callback)


def on_ui_train_tabs(callback):
    """register a function to be called when the UI is creating new tabs for the train tab.
    Create your new tabs with gr.Tab.
    """
    add_callback(callback_map['callbacks_ui_train_tabs'], callback)


def on_ui_settings(callback):
    """register a function to be called before UI settings are populated; add your settings
    by using shared.opts.add_option(shared.OptionInfo(...)) """
    add_callback(callback_map['callbacks_ui_settings'], callback)


def on_before_image_saved(callback):
    """register a function to be called before an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
    """
    add_callback(callback_map['callbacks_before_image_saved'], callback)


def on_image_saved(callback):
    """register a function to be called after an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callback_map['callbacks_image_saved'], callback)


def on_cfg_denoiser(callback):
    """register a function to be called in the kdiffusion cfg_denoiser method after building the inner model inputs.
    The callback is called with one argument:
        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoiser'], callback)


def on_cfg_denoised(callback):
    """register a function to be called in the kdiffusion cfg_denoiser method after the inner model has been called.
    The callback is called with one argument:
        - params: CFGDenoisedParams - the denoised latent and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoised'], callback)


def on_before_component(callback):
    """register a function to be called before a component is created.
    The callback is called with arguments:
        - component - gradio component that is about to be created.
        - **kwargs - args to gradio.components.IOComponent.__init__ function

    Use elem_id/label fields of kwargs to figure out which component it is.
    This can be useful to inject your own components somewhere in the middle of vanilla UI.
    """
    add_callback(callback_map['callbacks_before_component'], callback)


def on_after_component(callback):
    """register a function to be called after a component is created. See on_before_component for more."""
    add_callback(callback_map['callbacks_after_component'], callback)


def on_image_grid(callback):
    """register a function to be called before making an image grid.
    The callback is called with one argument:
        - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
    """
    add_callback(callback_map['callbacks_image_grid'], callback)


def on_infotext_pasted(callback):
    """register a function to be called before applying an infotext.
    The callback is called with two arguments:
        - infotext: str - raw infotext.
        - result: Dict[str, any] - parsed infotext parameters.
    """
    add_callback(callback_map['callbacks_infotext_pasted'], callback)


def on_script_unloaded(callback):
    """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
    the script did should be reverted here"""

    add_callback(callback_map['callbacks_script_unloaded'], callback)


def on_before_ui(callback):
    """register a function to be called before the UI is created."""

    add_callback(callback_map['callbacks_before_ui'], callback)
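A sketch (not part of the diff) of how an extension script would use these hooks; the option key and label are made up for illustration, and on_image_saved receives the ImageSaveParams object defined at the top of this file:

```python
# Hypothetical extension code: registers a settings option and an image-saved hook.
from modules import script_callbacks, shared


def on_settings():
    shared.opts.add_option(
        "my_extension_enabled",  # hypothetical option key
        shared.OptionInfo(True, "Enable my hypothetical extension", section=("my_ext", "My Extension")),
    )


def on_saved(params: script_callbacks.ImageSaveParams):
    print(f"saved {params.filename}")


script_callbacks.on_ui_settings(on_settings)
script_callbacks.on_image_saved(on_saved)
```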
sd/stable-diffusion-webui/modules/script_loading.py
ADDED
@@ -0,0 +1,32 @@
import os
import sys
import traceback
import importlib.util
from types import ModuleType


def load_module(path):
    module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    return module


def preload_extensions(extensions_dir, parser):
    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue

        try:
            module = load_module(preload_script)
            if hasattr(module, 'preload'):
                module.preload(parser)

        except Exception:
            print(f"Error running preload() for {preload_script}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
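A sketch (not part of the diff) of what an extension's preload.py could look like; preload_extensions() above calls preload(parser) with the webui's argparse parser before arguments are parsed. The flag name is made up for illustration:

```python
# extensions/<my-extension>/preload.py (hypothetical)
def preload(parser):
    parser.add_argument(
        "--my-extension-models-dir",
        type=str,
        default=None,
        help="hypothetical extra directory for this extension's models",
    )
```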
sd/stable-diffusion-webui/modules/scripts.py
ADDED
@@ -0,0 +1,501 @@
import os
import re
import sys
import traceback
from collections import namedtuple

import gradio as gr

from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing

AlwaysVisible = object()


class PostprocessImageArgs:
    def __init__(self, image):
        self.image = image


class Script:
    filename = None
    args_from = None
    args_to = None
    alwayson = False

    is_txt2img = False
    is_img2img = False

    """A gr.Group component that has all script's UI inside it"""
    group = None

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

        pass

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - script.AlwaysVisible if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
          - images - torch tensor with all generated images, with values ranging from 0 to 1
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as above.
        """

        pass

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for an HTML element; constructs the final id out of script name, tab and user-supplied item_id"""

        need_tabname = self.show(True) == self.show(False)
        tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'


current_basedir = paths.script_path


def basedir():
    """returns the base directory for the current script. For scripts in the main scripts directory,
    this is the main directory (where webui.py resides), and for scripts in the extensions directory
    (ie extensions/aesthetic/script/aesthetic.py), this is the extension's directory (extensions/aesthetic)
    """
    return current_basedir


ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])

scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])


def list_scripts(scriptdirname, extension):
    scripts_list = []

    basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(basedir):
        for filename in sorted(os.listdir(basedir)):
            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))

    for ext in extensions.active():
        scripts_list += ext.list_files(scriptdirname, extension)

    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

    return scripts_list


def list_files_with_name(filename):
    res = []

    dirs = [paths.script_path] + [ext.path for ext in extensions.active()]

    for dirpath in dirs:
        if not os.path.isdir(dirpath):
            continue

        path = os.path.join(dirpath, filename)
        if os.path.isfile(path):
            res.append(path)

    return res


def load_scripts():
    global current_basedir
    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")

    syspath = sys.path

    def register_scripts_from_module(module):
        for key, script_class in module.__dict__.items():
            if type(script_class) != type:
                continue

            if issubclass(script_class, Script):
                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    for scriptfile in sorted(scripts_list):
        try:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module)

        except Exception:
            print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    try:
        res = func(*args, **kwargs)
        return res
    except Exception:
        print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    return default


class ScriptRunner:
    def __init__(self):
        self.scripts = []
        self.selectable_scripts = []
        self.alwayson_scripts = []
        self.titles = []
        self.infotext_fields = []

    def initialize_scripts(self, is_img2img):
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
            script = script_class()
            script.filename = path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def setup_ui(self):
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        inputs = [None]
        inputs_alwayson = [True]

        def create_script_ui(script, inputs, inputs_alwayson):
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

            if controls is None:
                return

            for control in controls:
                control.custom_script_source = os.path.basename(script.filename)

            if script.infotext_fields is not None:
                self.infotext_fields += script.infotext_fields

            inputs += controls
            inputs_alwayson += [script.alwayson for _ in controls]
            script.args_to = len(inputs)

        for script in self.alwayson_scripts:
            with gr.Group() as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        inputs[0] = dropdown

        for script in self.selectable_scripts:
            with gr.Group(visible=False) as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        def select_script(script_index):
            selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            if title == 'None':
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            title = params.get('Script', None)
            if title:
                title_index = self.titles.index(title)
                visibility = title_index == self.script_load_ctr
                self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
                return gr.update(visible=visibility)
            else:
                return gr.update(visible=False)

        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        return inputs

    def run(self, p, *args):
        script_index = args[0]

        if script_index == 0:
            return None

        script = self.selectable_scripts[script_index-1]

        if script is None:
            return None

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                print(f"Error running process: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_batch(p, *script_args, **kwargs)
            except Exception:
                print(f"Error running process_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess(self, p, processed):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess_batch(self, p, images, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch(p, *script_args, images=images, **kwargs)
            except Exception:
                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_image(p, pp, *script_args)
            except Exception:
                print(f"Error running postprocess_image: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def before_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                print(f"Error running before_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def after_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                print(f"Error running after_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def reload_sources(self, cache):
        for si, script in list(enumerate(self.scripts)):
            args_from = script.args_from
            args_to = script.args_to
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            for key, script_class in module.__dict__.items():
                if type(script_class) == type and issubclass(script_class, Script):
                    self.scripts[si] = script_class()
                    self.scripts[si].filename = filename
                    self.scripts[si].args_from = args_from
                    self.scripts[si].args_to = args_to


scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
scripts_current: ScriptRunner = None


def reload_script_body_only():
    cache = {}
    scripts_txt2img.reload_sources(cache)
    scripts_img2img.reload_sources(cache)


def reload_scripts():
    global scripts_txt2img, scripts_img2img, scripts_postproc

    load_scripts()

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()


def IOComponent_init(self, *args, **kwargs):
    if scripts_current is not None:
        scripts_current.before_component(self, **kwargs)

    script_callbacks.before_component_callback(self, **kwargs)

    res = original_IOComponent_init(self, *args, **kwargs)

    script_callbacks.after_component_callback(self, **kwargs)

    if scripts_current is not None:
        scripts_current.after_component(self, **kwargs)

    return res


original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init
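For reference, a minimal sketch of a selectable script built against the Script API above; the class, its title, and the "Strength" slider are illustrative assumptions, not part of this commit.

```python
import gradio as gr

from modules import scripts
from modules.processing import process_images


class ExampleScript(scripts.Script):
    def title(self):
        return "Example script"  # shown in the Script dropdown

    def show(self, is_img2img):
        # True = selectable from the dropdown; returning scripts.AlwaysVisible
        # would route this script through the process()/postprocess() hooks instead
        return True

    def ui(self, is_img2img):
        strength = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Strength",
                             elem_id=self.elem_id("strength"))
        # components returned here arrive as positional args of run()
        return [strength]

    def run(self, p, strength):
        p.extra_generation_params["Example strength"] = strength  # hypothetical infotext key
        return process_images(p)
```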
sd/stable-diffusion-webui/modules/scripts_auto_postprocessing.py
ADDED
@@ -0,0 +1,42 @@
from modules import scripts, scripts_postprocessing, shared


class ScriptPostprocessingForMainUI(scripts.Script):
    def __init__(self, script_postproc):
        self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
        self.postprocessing_controls = None

    def title(self):
        return self.script.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        self.postprocessing_controls = self.script.ui()
        return self.postprocessing_controls.values()

    def postprocess_image(self, p, script_pp, *args):
        args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}

        pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
        pp.info = {}
        self.script.process(pp, **args_dict)
        p.extra_generation_params.update(pp.info)
        script_pp.image = pp.image


def create_auto_preprocessing_script_data():
    from modules import scripts

    res = []

    for name in shared.opts.postprocessing_enable_in_main_ui:
        script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
        if script is None:
            continue

        constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
        res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))

    return res
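One detail worth noting in create_auto_preprocessing_script_data() above is the `lambda s=script:` constructor: the default argument freezes the loop variable per iteration. A standalone, plain-Python sketch of the pitfall it avoids (nothing here is webui code):

```python
# Closures capture variables, not values, so every "broken" lambda below
# sees the final loop value; the default argument binds at definition time.
makers_broken = [lambda: name for name in ["a", "b", "c"]]
makers_fixed = [lambda n=name: n for name in ["a", "b", "c"]]

print([f() for f in makers_broken])  # ['c', 'c', 'c']
print([f() for f in makers_fixed])   # ['a', 'b', 'c']
```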
sd/stable-diffusion-webui/modules/scripts_postprocessing.py
ADDED
@@ -0,0 +1,152 @@
import os
import gradio as gr

from modules import errors, shared


class PostprocessedImage:
    def __init__(self, image):
        self.image = image
        self.info = {}


class ScriptPostprocessing:
    filename = None
    controls = None
    args_from = None
    args_to = None

    order = 1000
    """scripts will be ordered by this value in postprocessing UI"""

    name = None
    """the name of the script; shown as its title in the postprocessing UI"""

    group = None
    """A gr.Group component that has all script's UI inside it"""

    def ui(self):
        """
        This function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be a dictionary that maps parameter names to components used in processing.
        Values of those components will be passed to process() function.
        """

        pass

    def process(self, pp: PostprocessedImage, **args):
        """
        This function is called to postprocess the image.
        args contains a dictionary with all values returned by components from ui()
        """

        pass

    def image_changed(self):
        pass


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    try:
        res = func(*args, **kwargs)
        return res
    except Exception as e:
        errors.display(e, f"calling {filename}/{funcname}")

    return default


class ScriptPostprocessingRunner:
    def __init__(self):
        self.scripts = None
        self.ui_created = False

    def initialize_scripts(self, scripts_data):
        self.scripts = []

        for script_class, path, basedir, script_module in scripts_data:
            script: ScriptPostprocessing = script_class()
            script.filename = path

            if script.name == "Simple Upscale":
                continue

            self.scripts.append(script)

    def create_script_ui(self, script, inputs):
        script.args_from = len(inputs)
        script.args_to = len(inputs)

        script.controls = wrap_call(script.ui, script.filename, "ui")

        for control in script.controls.values():
            control.custom_script_source = os.path.basename(script.filename)

        inputs += list(script.controls.values())
        script.args_to = len(inputs)

    def scripts_in_preferred_order(self):
        if self.scripts is None:
            import modules.scripts
            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

        scripts_order = shared.opts.postprocessing_operation_order

        def script_score(name):
            for i, possible_match in enumerate(scripts_order):
                if possible_match == name:
                    return i

            return len(self.scripts)

        script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}

        return sorted(self.scripts, key=lambda x: script_scores[x.name])

    def setup_ui(self):
        inputs = []

        for script in self.scripts_in_preferred_order():
            with gr.Box() as group:
                self.create_script_ui(script, inputs)

            script.group = group

        self.ui_created = True
        return inputs

    def run(self, pp: PostprocessedImage, args):
        for script in self.scripts_in_preferred_order():
            shared.state.job = script.name

            script_args = args[script.args_from:script.args_to]

            process_args = {}
            for (name, component), value in zip(script.controls.items(), script_args):
                process_args[name] = value

            script.process(pp, **process_args)

    def create_args_for_run(self, scripts_args):
        if not self.ui_created:
            with gr.Blocks(analytics_enabled=False):
                self.setup_ui()

        scripts = self.scripts_in_preferred_order()
        args = [None] * max([x.args_to for x in scripts])

        for script in scripts:
            script_args_dict = scripts_args.get(script.name, None)
            if script_args_dict is not None:
                for i, name in enumerate(script.controls):
                    args[script.args_from + i] = script_args_dict.get(name, None)

        return args

    def image_changed(self):
        for script in self.scripts_in_preferred_order():
            script.image_changed()
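A minimal sketch of a ScriptPostprocessing implementation against the API above; the script name, the checkbox, and the PIL flip step are illustrative assumptions, not part of this commit.

```python
import gradio as gr
from PIL import Image

from modules import scripts_postprocessing


class ScriptPostprocessingFlip(scripts_postprocessing.ScriptPostprocessing):
    name = "Horizontal flip (example)"
    order = 2000  # sorts after built-ins unless overridden by postprocessing_operation_order

    def ui(self):
        # keys of this dict become keyword arguments of process()
        enable = gr.Checkbox(False, label="Flip horizontally")
        return {"enable": enable}

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable):
        if not enable:
            return

        pp.image = pp.image.transpose(Image.FLIP_LEFT_RIGHT)
        pp.info["Horizontal flip"] = True  # merged into the image's infotext
```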
sd/stable-diffusion-webui/modules/sd_disable_initialization.py
ADDED
@@ -0,0 +1,93 @@
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub


class DisableInitialization:
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from working
    - changing CLIP and OpenCLIP to not download model weights
    - changing CLIP to not make requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __init__(self, disable_clip=True):
        self.replaced = []
        self.disable_clip = disable_clip

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
                return None

            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)

        self.replaced.clear()
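The module above is one instance of a general swap-and-restore monkeypatching pattern: stash the original attribute, install a stub, and put everything back on exit even if the body raises. A stripped-down, self-contained sketch of the same idea, assuming only torch (this is not webui code):

```python
import torch


class PatchAttributes:
    def __init__(self):
        self.replaced = []

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        # remember the originals, then neuter torch's init functions
        for field in ('kaiming_uniform_', '_no_grad_normal_', '_no_grad_uniform_'):
            self.replaced.append((torch.nn.init, field, getattr(torch.nn.init, field)))
            setattr(torch.nn.init, field, do_nothing)

    def __exit__(self, exc_type, exc_val, exc_tb):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)
        self.replaced.clear()


with PatchAttributes():
    layer = torch.nn.Linear(4, 4)  # weights are left as uninitialized memory here
```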
sd/stable-diffusion-webui/modules/sd_hijack.py
ADDED
@@ -0,0 +1,264 @@
import torch
from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None


def apply_optimizations():
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    optimization_method = None

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        optimization_method = 'xformers'
    elif cmd_opts.opt_sub_quad_attention:
        print("Applying sub-quadratic cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
        optimization_method = 'sub-quadratic'
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        optimization_method = 'V1'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
        print("Applying cross attention optimization (InvokeAI).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
        optimization_method = 'InvokeAI'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        optimization_method = 'Doggettx'

    return optimization_method


def undo_optimizations():
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""

    pass


def weighted_loss(sd_model, pred, target, mean=True):
    # Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    # Return the loss, as mean if specified
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        # Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        # Keep 'get_loss', but don't overwrite the previous _old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    # Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass


class StableDiffusionModelHijack:
    fixes = None
    comments = []
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.optimization_method = apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []

    def get_prompt_lengths(self, text):
        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = devices.cond_cast_unet(embedding.vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)


def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular


model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)


ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
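The splice that EmbeddingsWithFixes.forward performs can be checked in isolation: the embedding's vectors overwrite the token slots after the trigger position while the sequence length is preserved. A toy example with stand-in tensors (none of these values come from the webui):

```python
import torch

tensor = torch.zeros(7, 4)  # stand-in for one prompt's token embeddings (7 tokens, dim 4)
emb = torch.ones(2, 4)      # stand-in for a 2-vector textual inversion embedding
offset = 2                  # the fix says the embedding starts after token index 2

# same arithmetic as in forward(): clamp to the remaining room, then splice
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
spliced = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

assert spliced.shape == tensor.shape  # length preserved; rows 3-4 now hold the embedding
```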
sd/stable-diffusion-webui/modules/sd_hijack_checkpoint.py
ADDED
@@ -0,0 +1,46 @@
from torch.utils.checkpoint import checkpoint

import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel


def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)


def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)


def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)


stored = []


def add():
    if len(stored) != 0:
        return

    stored.extend([
        ldm.modules.attention.BasicTransformerBlock.forward,
        ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
        ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
    ])

    ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward


def remove():
    if len(stored) == 0:
        return

    ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]

    stored.clear()
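The hijack above swaps each block's forward for a checkpointed one: activations are recomputed in the backward pass instead of being stored, trading compute for training memory. A self-contained sketch of the same trade on a toy module, assuming only torch (not webui code):

```python
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())

    def _forward(self, x):
        return self.net(x)

    def forward(self, x):
        # intermediates of _forward are not kept; they are recomputed on backward
        return checkpoint(self._forward, x)


x = torch.randn(4, 16, requires_grad=True)
Block()(x).sum().backward()  # gradients still flow as usual
```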
sd/stable-diffusion-webui/modules/sd_hijack_clip.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
import math
|
2 |
+
from collections import namedtuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from modules import prompt_parser, devices, sd_hijack
|
7 |
+
from modules.shared import opts
|
8 |
+
|
9 |
+
|
10 |
+
class PromptChunk:
|
11 |
+
"""
|
12 |
+
This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
|
13 |
+
If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
|
14 |
+
Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
|
15 |
+
so just 75 tokens from prompt.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self):
|
19 |
+
self.tokens = []
|
20 |
+
self.multipliers = []
|
21 |
+
self.fixes = []
|
22 |
+
|
23 |
+
|
24 |
+
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
25 |
+
"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
|
26 |
+
chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
27 |
+
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
28 |
+
|
29 |
+
|
30 |
+
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
31 |
+
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
32 |
+
have unlimited prompt length and assign weights to tokens in prompt.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, wrapped, hijack):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.wrapped = wrapped
|
39 |
+
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
40 |
+
depending on model."""
|
41 |
+
|
42 |
+
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
43 |
+
self.chunk_length = 75
|
44 |
+
|
45 |
+
def empty_chunk(self):
|
46 |
+
"""creates an empty PromptChunk and returns it"""
|
47 |
+
|
48 |
+
chunk = PromptChunk()
|
49 |
+
chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
|
50 |
+
chunk.multipliers = [1.0] * (self.chunk_length + 2)
|
51 |
+
return chunk
|
52 |
+
|
53 |
+
def get_target_prompt_token_count(self, token_count):
|
54 |
+
"""returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
|
55 |
+
|
56 |
+
return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
|
57 |
+
|
58 |
+
def tokenize(self, texts):
|
59 |
+
"""Converts a batch of texts into a batch of token ids"""
|
60 |
+
|
61 |
+
raise NotImplementedError
|
62 |
+
|
63 |
+
def encode_with_transformers(self, tokens):
|
64 |
+
"""
|
65 |
+
converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
|
66 |
+
All python lists with tokens are assumed to have same length, usually 77.
|
67 |
+
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
|
68 |
+
model - can be 768 and 1024.
|
69 |
+
Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
|
70 |
+
"""
|
71 |
+
|
72 |
+
raise NotImplementedError
|
73 |
+
|
74 |
+
def encode_embedding_init_text(self, init_text, nvpt):
|
75 |
+
"""Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
|
76 |
+
transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
|
77 |
+
|
78 |
+
raise NotImplementedError
|
79 |
+
|
80 |
+
def tokenize_line(self, line):
|
81 |
+
"""
|
82 |
+
this transforms a single prompt into a list of PromptChunk objects - as many as needed to
|
83 |
+
represent the prompt.
|
84 |
+
Returns the list and the total number of tokens in the prompt.
|
85 |
+
"""
|
86 |
+
|
87 |
+
if opts.enable_emphasis:
|
88 |
+
parsed = prompt_parser.parse_prompt_attention(line)
|
89 |
+
else:
|
90 |
+
parsed = [[line, 1.0]]
|
91 |
+
|
92 |
+
tokenized = self.tokenize([text for text, _ in parsed])
|
93 |
+
|
94 |
+
chunks = []
|
95 |
+
chunk = PromptChunk()
|
96 |
+
token_count = 0
|
97 |
+
last_comma = -1
|
98 |
+
|
99 |
+
def next_chunk(is_last=False):
|
100 |
+
"""puts current chunk into the list of results and produces the next one - empty;
|
101 |
+
if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count"""
|
102 |
+
nonlocal token_count
|
103 |
+
nonlocal last_comma
|
104 |
+
nonlocal chunk
|
105 |
+
|
106 |
+
if is_last:
|
107 |
+
token_count += len(chunk.tokens)
|
108 |
+
else:
|
109 |
+
token_count += self.chunk_length
|
110 |
+
|
111 |
+
to_add = self.chunk_length - len(chunk.tokens)
|
112 |
+
if to_add > 0:
|
113 |
+
chunk.tokens += [self.id_end] * to_add
|
114 |
+
chunk.multipliers += [1.0] * to_add
|
115 |
+
|
116 |
+
chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
|
117 |
+
chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
|
118 |
+
|
119 |
+
last_comma = -1
|
120 |
+
chunks.append(chunk)
|
121 |
+
chunk = PromptChunk()
|
122 |
+
|
123 |
+
for tokens, (text, weight) in zip(tokenized, parsed):
|
124 |
+
if text == 'BREAK' and weight == -1:
|
125 |
+
next_chunk()
|
126 |
+
continue
|
127 |
+
|
128 |
+
position = 0
|
129 |
+
while position < len(tokens):
|
130 |
+
token = tokens[position]
|
131 |
+
|
132 |
+
if token == self.comma_token:
|
133 |
+
last_comma = len(chunk.tokens)
|
134 |
+
|
135 |
+
# this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
136 |
+
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
137 |
+
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
138 |
+
break_location = last_comma + 1
|
139 |
+
|
140 |
+
reloc_tokens = chunk.tokens[break_location:]
|
141 |
+
reloc_mults = chunk.multipliers[break_location:]
|
142 |
+
|
143 |
+
chunk.tokens = chunk.tokens[:break_location]
|
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                emb_len = int(embedding.vec.shape[0])
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens

        if len(chunk.tokens) > 0 or len(chunks) == 0:
            next_chunk(is_last=True)

        return chunks, token_count

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and the maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; passes the texts through the transformers network to create a tensor with a numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
        An example shape returned by this function can be: (2, 77, 768).
        The webui usually sends just one text at a time through this function - the only time texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        if opts.use_old_emphasis_implementation:
            import modules.sd_hijack_clip_old
            return modules.sd_hijack_clip_old.forward_old(self, texts)

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        if len(used_embeddings) > 0:
            embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")

        return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        Sends a single prompt chunk to be encoded by the transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring the original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
        original_mean = z.mean()
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)

        return z

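# A worked illustration of the mean renormalisation above (numbers invented):
# if z.mean() is 0.20 before the per-token weights and 0.25 after multiplying,
# every element is scaled by 0.20 / 0.25 = 0.8, so the overall activation scale
# the model was trained with is preserved while the relative emphasis is kept.
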
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        self.id_pad = self.id_end

    def tokenize(self, texts):
        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)

        if opts.CLIP_stop_at_last_layers > 1:
            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
            z = self.wrapped.transformer.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded
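A quick usage sketch of the wrapper above (not part of the upload; `clip` is assumed to be an already-hijacked FrozenCLIPEmbedderWithCustomWords instance), showing the (B, T, C) shapes the forward() docstring promises:

cond = clip(["a photo of a cat"])                          # one prompt -> B == 1
print(cond.shape)                                          # e.g. torch.Size([1, 77, 768]) on SD1
cond = clip(["a picture of a cat", "a picture of a dog"])  # prompt editing sends two texts at once
print(cond.shape)                                          # e.g. torch.Size([2, 77, 768])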
sd/stable-diffusion-webui/modules/sd_hijack_clip_old.py
ADDED
@@ -0,0 +1,81 @@
from modules import sd_hijack_clip
from modules import shared


def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    id_start = self.id_start
    id_end = self.id_end
    maxlen = self.wrapped.max_length  # you get to stay at 77
    used_custom_terms = []
    remade_batch_tokens = []
    hijack_comments = []
    hijack_fixes = []
    token_count = 0

    cache = {}
    batch_tokens = self.tokenize(texts)
    batch_multipliers = []
    for tokens in batch_tokens:
        tuple_tokens = tuple(tokens)

        if tuple_tokens in cache:
            remade_tokens, fixes, multipliers = cache[tuple_tokens]
        else:
            fixes = []
            remade_tokens = []
            multipliers = []
            mult = 1.0

            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
                if mult_change is not None:
                    mult *= mult_change
                    i += 1
                elif embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(mult)
                    i += 1
                else:
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [mult] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

            if len(remade_tokens) > maxlen - 2:
                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                ovf = remade_tokens[maxlen - 2:]
                overflowing_words = [vocab.get(int(x), "") for x in ovf]
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

            token_count = len(remade_tokens)
            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        remade_batch_tokens.append(remade_tokens)
        hijack_fixes.append(fixes)
        batch_multipliers.append(multipliers)
    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count


def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)

    self.hijack.comments += hijack_comments

    if len(used_custom_terms) > 0:
        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

    self.hijack.fixes = hijack_fixes
    return self.process_tokens(remade_batch_tokens, batch_multipliers)
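To make the multiplier bookkeeping in process_text_old concrete, here is a minimal standalone sketch of the same loop with a fake token stream (ids 100/101 standing in for '(' and ')' are invented):

token_mults = {100: 1.1, 101: 1 / 1.1}   # stand-in for the dict built from the vocab
tokens = [100, 5, 6, 101, 7]             # schematically: "(cat dog) tree"
remade, mults, mult = [], [], 1.0
for t in tokens:
    change = token_mults.get(t)
    if change is not None:
        mult *= change        # bracket token is consumed; only the running weight changes
    else:
        remade.append(t)
        mults.append(mult)
print(remade, mults)          # [5, 6, 7] with weights [1.1, 1.1, ~1.0]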
sd/stable-diffusion-webui/modules/sd_hijack_inpainting.py
ADDED
@@ -0,0 +1,103 @@
import os
import torch

from einops import repeat
from omegaconf import ListConfig

import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms

from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding


@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)

            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])

            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t


def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings

    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
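The pseudo linear multistep branches above use Adams-Bashforth weights; a standalone sanity check (nothing from the module) that each coefficient set sums to its denominator, so e_t_prime stays on the same scale as e_t:

for coeffs, denom in (([3, -1], 2), ([23, -16, 5], 12), ([55, -59, 37, -9], 24)):
    assert sum(coeffs) == denom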
sd/stable-diffusion-webui/modules/sd_hijack_ip2p.py
ADDED
@@ -0,0 +1,13 @@
import collections
import os.path
import sys
import gc
import time

def should_hijack_ip2p(checkpoint_info):
    from modules import sd_models_config

    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()

    return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
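An illustration of the filename rule in should_hijack_ip2p, reproduced standalone with made-up paths (the real function reads them from checkpoint_info):

import os.path

def rule(ckpt, cfg):
    return "pix2pix" in os.path.basename(ckpt).lower() and "pix2pix" not in os.path.basename(cfg).lower()

print(rule("models/instruct-pix2pix-00-22000.safetensors", "configs/v1-inference.yaml"))      # True: hijack
print(rule("models/instruct-pix2pix-00-22000.safetensors", "configs/instruct-pix2pix.yaml"))  # False: config already matches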
sd/stable-diffusion-webui/modules/sd_hijack_open_clip.py
ADDED
@@ -0,0 +1,37 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts

tokenizer = open_clip.tokenizer._tokenizer


class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
        self.id_start = tokenizer.encoder["<start_of_text>"]
        self.id_end = tokenizer.encoder["<end_of_text>"]
        self.id_pad = 0

    def tokenize(self, texts):
        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'

        tokenized = [tokenizer.encode(text) for text in texts]

        return tokenized

    def encode_with_transformers(self, tokens):
        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
        z = self.wrapped.encode_with_transformer(tokens)

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)

        return embedded
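A short look at the special ids the wrapper reads from open_clip's tokenizer (the printed values are what current open_clip builds ship, but treat them as illustrative):

import open_clip.tokenizer

tok = open_clip.tokenizer._tokenizer
print(tok.encoder["<start_of_text>"], tok.encoder["<end_of_text>"])  # e.g. 49406 49407
print(tok.encode("a cat"))  # bare BPE ids; no start/end/padding, unlike the HF tokenizer path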
sd/stable-diffusion-webui/modules/sd_hijack_optimizations.py
ADDED
@@ -0,0 +1,444 @@
import math
import sys
import traceback
import psutil

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices
from modules.hypernetworks import hypernetwork

from .sub_quadratic_attention import efficient_dot_product_attention


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)


def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)

def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)

def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r

def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r

def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)

def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)

def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))

def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # divide by a factor of safety, as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)

def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --

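# Worked example of einsum_op_tensor_mem above (numbers invented): for a
# 1536 MB attention matrix with max_tensor_mb = 512, the divisor is
# 1 << int((1536 - 1) / 512).bit_length() == 1 << 2 == 4, so the batch
# dimension (or, failing that, the query length) is cut into 4 slices.
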
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x

def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        query_chunk_size = q_tokens
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )


def get_xformers_flash_attention_op(q, k, v):
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None


def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)

def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)    # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3

def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)

def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out
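The Doggettx-style slicing above picks a power-of-two number of slices from the memory estimate; a standalone sketch of just that heuristic (the function name is mine):

import math

def attention_steps(mem_required, mem_free_total):
    # double the slice count until the per-slice attention matrix fits in free memory
    if mem_required <= mem_free_total:
        return 1
    return 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))

print(attention_steps(3.2 * 1024**3, 2.0 * 1024**3))  # 2
print(attention_steps(9.0 * 1024**3, 2.0 * 1024**3))  # 8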
sd/stable-diffusion-webui/modules/sd_hijack_unet.py
ADDED
@@ -0,0 +1,79 @@
import torch
from packaging import version

from modules import devices
from modules.sd_hijack_utils import CondFunc


class TorchHijackForUnet:
    """
    This is torch, but with a cat() that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()


# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):

    if isinstance(cond, dict):
        for y in cond.keys():
            cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]

    with devices.autocast():
        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()


class GELUHijack(torch.nn.GELU, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)
    def forward(self, x):
        if devices.unet_needs_upcast:
            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
        else:
            return torch.nn.GELU.forward(self, x)


ddpm_edit_hijack = None
def hijack_ddpm_edit():
    global ddpm_edit_hijack
    if not ddpm_edit_hijack:
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)


unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
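What the cat() override above does to mismatched skip connections, reproduced standalone (shapes invented):

import torch

a = torch.zeros(1, 4, 20, 20)
b = torch.zeros(1, 4, 19, 19)
# th.cat would interpolate a down to b's spatial size before concatenating:
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
print(torch.cat((a, b), dim=1).shape)  # torch.Size([1, 8, 19, 19])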
sd/stable-diffusion-webui/modules/sd_hijack_utils.py
ADDED
@@ -0,0 +1,28 @@
import importlib

class CondFunc:
    def __new__(cls, orig_func, sub_func, cond_func):
        self = super(CondFunc, cls).__new__(cls)
        if isinstance(orig_func, str):
            func_path = orig_func.split('.')
            for i in range(len(func_path)-1, -1, -1):
                try:
                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                    break
                except ImportError:
                    pass
            for attr_name in func_path[i:-1]:
                resolved_obj = getattr(resolved_obj, attr_name)
            orig_func = getattr(resolved_obj, func_path[-1])
            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: self(*args, **kwargs)
    def __init__(self, orig_func, sub_func, cond_func):
        self.__orig_func = orig_func
        self.__sub_func = sub_func
        self.__cond_func = cond_func
    def __call__(self, *args, **kwargs):
        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
            return self.__sub_func(self.__orig_func, *args, **kwargs)
        else:
            return self.__orig_func(*args, **kwargs)
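A minimal self-contained sketch of CondFunc with a callable instead of a dotted path (in this form no patching happens; the returned wrapper is used directly):

def orig(x):
    return x + 1

wrapped = CondFunc(orig,
                   lambda orig_func, x: orig_func(x) * 10,  # substitute behaviour
                   lambda orig_func, x: x > 0)              # condition
print(wrapped(2))   # 30 -- condition holds, sub_func runs: (2 + 1) * 10
print(wrapped(-2))  # -1 -- condition fails, falls back to the original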