AdamOswald1 committed
Commit 6fa895d · 1 Parent(s): 3f2f51a

Upload 77 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete changeset.
Files changed (50)
  1. sd/stable-diffusion-webui/modules/call_queue.py +109 -0
  2. sd/stable-diffusion-webui/modules/codeformer_model.py +143 -0
  3. sd/stable-diffusion-webui/modules/deepbooru.py +99 -0
  4. sd/stable-diffusion-webui/modules/deepbooru_model.py +678 -0
  5. sd/stable-diffusion-webui/modules/devices.py +152 -0
  6. sd/stable-diffusion-webui/modules/errors.py +43 -0
  7. sd/stable-diffusion-webui/modules/esrgan_model.py +233 -0
  8. sd/stable-diffusion-webui/modules/esrgan_model_arch.py +464 -0
  9. sd/stable-diffusion-webui/modules/extensions.py +107 -0
  10. sd/stable-diffusion-webui/modules/extra_networks.py +147 -0
  11. sd/stable-diffusion-webui/modules/extra_networks_hypernet.py +27 -0
  12. sd/stable-diffusion-webui/modules/extras.py +258 -0
  13. sd/stable-diffusion-webui/modules/face_restoration.py +19 -0
  14. sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py +402 -0
  15. sd/stable-diffusion-webui/modules/gfpgan_model.py +116 -0
  16. sd/stable-diffusion-webui/modules/hashes.py +91 -0
  17. sd/stable-diffusion-webui/modules/images.py +669 -0
  18. sd/stable-diffusion-webui/modules/img2img.py +184 -0
  19. sd/stable-diffusion-webui/modules/import_hook.py +5 -0
  20. sd/stable-diffusion-webui/modules/interrogate.py +227 -0
  21. sd/stable-diffusion-webui/modules/localization.py +37 -0
  22. sd/stable-diffusion-webui/modules/lowvram.py +96 -0
  23. sd/stable-diffusion-webui/modules/mac_specific.py +53 -0
  24. sd/stable-diffusion-webui/modules/masking.py +99 -0
  25. sd/stable-diffusion-webui/modules/memmon.py +88 -0
  26. sd/stable-diffusion-webui/modules/modelloader.py +172 -0
  27. sd/stable-diffusion-webui/modules/ngrok.py +26 -0
  28. sd/stable-diffusion-webui/modules/paths.py +62 -0
  29. sd/stable-diffusion-webui/modules/postprocessing.py +103 -0
  30. sd/stable-diffusion-webui/modules/processing.py +1056 -0
  31. sd/stable-diffusion-webui/modules/progress.py +99 -0
  32. sd/stable-diffusion-webui/modules/prompt_parser.py +373 -0
  33. sd/stable-diffusion-webui/modules/realesrgan_model.py +129 -0
  34. sd/stable-diffusion-webui/modules/safe.py +192 -0
  35. sd/stable-diffusion-webui/modules/script_callbacks.py +359 -0
  36. sd/stable-diffusion-webui/modules/script_loading.py +32 -0
  37. sd/stable-diffusion-webui/modules/scripts.py +501 -0
  38. sd/stable-diffusion-webui/modules/scripts_auto_postprocessing.py +42 -0
  39. sd/stable-diffusion-webui/modules/scripts_postprocessing.py +152 -0
  40. sd/stable-diffusion-webui/modules/sd_disable_initialization.py +93 -0
  41. sd/stable-diffusion-webui/modules/sd_hijack.py +264 -0
  42. sd/stable-diffusion-webui/modules/sd_hijack_checkpoint.py +46 -0
  43. sd/stable-diffusion-webui/modules/sd_hijack_clip.py +317 -0
  44. sd/stable-diffusion-webui/modules/sd_hijack_clip_old.py +81 -0
  45. sd/stable-diffusion-webui/modules/sd_hijack_inpainting.py +103 -0
  46. sd/stable-diffusion-webui/modules/sd_hijack_ip2p.py +13 -0
  47. sd/stable-diffusion-webui/modules/sd_hijack_open_clip.py +37 -0
  48. sd/stable-diffusion-webui/modules/sd_hijack_optimizations.py +444 -0
  49. sd/stable-diffusion-webui/modules/sd_hijack_unet.py +79 -0
  50. sd/stable-diffusion-webui/modules/sd_hijack_utils.py +28 -0
sd/stable-diffusion-webui/modules/call_queue.py ADDED
@@ -0,0 +1,109 @@
import html
import sys
import threading
import traceback
import time

from modules import shared, progress

queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return f


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        # if the first argument is a string that says "task(...)", it is treated as a job id
        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
            id_task = args[0]
            progress.add_task_to_queue(id_task)
        else:
            id_task = None

        with queue_lock:
            shared.state.begin()
            progress.start_task(id_task)

            try:
                res = func(*args, **kwargs)
            finally:
                progress.finish_task(id_task)

            shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
            max_debug_str_len = 131072 # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {str(args)} {str(kwargs)}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)

            print(traceback.format_exc(), file=sys.stderr)

            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f
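For orientation, a minimal usage sketch of these wrappers (an editorial aside, not part of the commit; the `generate` handler and its outputs below are hypothetical): a Gradio click handler is wrapped so GPU work is serialized behind `queue_lock`, a leading `"task(...)"` argument is registered with the progress module, and timing/VRAM statistics are appended to the last (HTML) output.

# Hypothetical caller sketch -- assumes the webui's modules package is importable and initialized.
from modules import call_queue

def generate(id_task, prompt):            # hypothetical handler; first arg may be a "task(...)" id
    images = ["result.png"]               # placeholder for real generation output
    return images, "<div>done</div>"      # last output must be HTML: stats get appended to it

wrapped = call_queue.wrap_gradio_gpu_call(generate, extra_outputs=[None, ''])
# Gradio would then call: wrapped("task(123)", "a prompt")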
sd/stable-diffusion-webui/modules/codeformer_model.py ADDED
@@ -0,0 +1,143 @@
import os
import sys
import traceback

import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
from modules.paths import models_path

# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

have_codeformer = False
codeformer = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return

    try:
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer
        from basicsr.utils.download_util import load_file_from_url
        from basicsr.utils import imwrite, img2tensor, tensor2img
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
        from facelib.detection.retinaface import retinaface
        from modules.shared import cmd_opts

        net_class = CodeFormer

        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
            def name(self):
                return "CodeFormer"

            def __init__(self, dirname):
                self.net = None
                self.face_helper = None
                self.cmd_dir = dirname

            def create_models(self):

                if self.net is not None and self.face_helper is not None:
                    self.net.to(devices.device_codeformer)
                    return self.net, self.face_helper
                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
                if len(model_paths) != 0:
                    ckpt_path = model_paths[0]
                else:
                    print("Unable to load codeformer model.")
                    return None, None
                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                checkpoint = torch.load(ckpt_path)['params_ema']
                net.load_state_dict(checkpoint)
                net.eval()

                if hasattr(retinaface, 'device'):
                    retinaface.device = devices.device_codeformer
                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)

                self.net = net
                self.face_helper = face_helper

                return net, face_helper

            def send_model_to(self, device):
                self.net.to(device)
                self.face_helper.face_det.to(device)
                self.face_helper.face_parse.to(device)

            def restore(self, np_image, w=None):
                np_image = np_image[:, :, ::-1]

                original_resolution = np_image.shape[0:2]

                self.create_models()
                if self.net is None or self.face_helper is None:
                    return np_image

                self.send_model_to(devices.device_codeformer)

                self.face_helper.clean_all()
                self.face_helper.read_image(np_image)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()

                for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

                    try:
                        with torch.no_grad():
                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                        del output
                        torch.cuda.empty_cache()
                    except Exception as error:
                        print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    restored_face = restored_face.astype('uint8')
                    self.face_helper.add_restored_face(restored_face)

                self.face_helper.get_inverse_affine(None)

                restored_img = self.face_helper.paste_faces_to_input_image()
                restored_img = restored_img[:, :, ::-1]

                if original_resolution != restored_img.shape[0:2]:
                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)

                self.face_helper.clean_all()

                if shared.opts.face_restoration_unload:
                    self.send_model_to(devices.cpu)

                return restored_img

        global have_codeformer
        have_codeformer = True

        global codeformer
        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)

    except Exception:
        print("Error setting up CodeFormer:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    # sys.path = stored_sys_path
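A hedged usage sketch (not part of the commit): `setup_model` registers the restorer on `shared.face_restorers`, and `restore` takes an RGB uint8 numpy array and returns it with restored faces pasted back at the original resolution. The directory and image paths below are illustrative assumptions.

# Hypothetical usage sketch -- assumes the webui environment (modules.shared, facelib, basicsr) is set up.
import numpy as np
from PIL import Image
from modules import codeformer_model

codeformer_model.setup_model("models/Codeformer")              # dirname here is an assumed example path
if codeformer_model.have_codeformer:
    img = np.array(Image.open("face.png").convert("RGB"))      # hypothetical input image
    restored = codeformer_model.codeformer.restore(img, w=0.5)  # w: CodeFormer fidelity weight (0..1)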
sd/stable-diffusion-webui/modules/deepbooru.py ADDED
@@ -0,0 +1,99 @@
import os
import re

import torch
from PIL import Image
import numpy as np

from modules import modelloader, paths, deepbooru_model, devices, images, shared

re_special = re.compile(r'([\\()])')


class DeepDanbooru:
    def __init__(self):
        self.model = None

    def load(self):
        if self.model is not None:
            return

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
            ext_filter=[".pt"],
            download_name='model-resnet_custom_v3.pt',
        )

        self.model = deepbooru_model.DeepDanbooruModel()
        self.model.load_state_dict(torch.load(files[0], map_location="cpu"))

        self.model.eval()
        self.model.to(devices.cpu, devices.dtype)

    def start(self):
        self.load()
        self.model.to(devices.device)

    def stop(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model.to(devices.cpu)
            devices.torch_gc()

    def tag(self, pil_image):
        self.start()
        res = self.tag_multi(pil_image)
        self.stop()

        return res

    def tag_multi(self, pil_image, force_disable_ranks=False):
        threshold = shared.opts.interrogate_deepbooru_score_threshold
        use_spaces = shared.opts.deepbooru_use_spaces
        use_escape = shared.opts.deepbooru_escape
        alpha_sort = shared.opts.deepbooru_sort_alpha
        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks

        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(a).to(devices.device)
            y = self.model(x)[0].detach().cpu().numpy()

        probability_dict = {}

        for tag, probability in zip(self.model.tags, y):
            if probability < threshold:
                continue

            if tag.startswith("rating:"):
                continue

            probability_dict[tag] = probability

        if alpha_sort:
            tags = sorted(probability_dict)
        else:
            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]

        res = []

        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])

        for tag in [x for x in tags if x not in filtertags]:
            probability = probability_dict[tag]
            tag_outformat = tag
            if use_spaces:
                tag_outformat = tag_outformat.replace('_', ' ')
            if use_escape:
                tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
            if include_ranks:
                tag_outformat = f"({tag_outformat}:{probability:.3f})"

            res.append(tag_outformat)

        return ", ".join(res)


model = DeepDanbooru()
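A minimal usage sketch (an editorial aside, not part of the commit): the module-level `model` instance lazily downloads and loads the network on first use; `tag` moves it to the configured device, runs one image, and returns a comma-separated tag string filtered by the `interrogate_deepbooru_score_threshold` option. The input filename below is a hypothetical example.

# Hypothetical usage sketch -- assumes the webui's options (shared.opts) are initialized.
from PIL import Image
from modules import deepbooru

pil_image = Image.open("sample.png")       # hypothetical input file
tags = deepbooru.model.tag(pil_image)      # e.g. "1girl, solo, smile"
print(tags)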
sd/stable-diffusion-webui/modules/deepbooru_model.py ADDED
@@ -0,0 +1,678 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from modules import devices
6
+
7
+ # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
8
+
9
+
10
+ class DeepDanbooruModel(nn.Module):
11
+ def __init__(self):
12
+ super(DeepDanbooruModel, self).__init__()
13
+
14
+ self.tags = []
15
+
16
+ self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
17
+ self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
18
+ self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
19
+ self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
20
+ self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
21
+ self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
22
+ self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
23
+ self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
24
+ self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
25
+ self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
26
+ self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
27
+ self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
28
+ self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
29
+ self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
30
+ self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
31
+ self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
32
+ self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
33
+ self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
34
+ self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
35
+ self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
36
+ self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
37
+ self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
38
+ self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
39
+ self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
40
+ self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
41
+ self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
42
+ self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
43
+ self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
44
+ self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
45
+ self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
46
+ self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
47
+ self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
48
+ self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
49
+ self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
50
+ self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
51
+ self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
52
+ self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
53
+ self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
54
+ self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
55
+ self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
56
+ self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
57
+ self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
58
+ self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
59
+ self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
60
+ self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
61
+ self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
62
+ self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
63
+ self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
64
+ self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
65
+ self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
66
+ self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
67
+ self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
68
+ self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
69
+ self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
70
+ self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
71
+ self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
72
+ self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
73
+ self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
74
+ self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
75
+ self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
76
+ self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
77
+ self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
78
+ self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
79
+ self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
80
+ self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
81
+ self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
82
+ self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
83
+ self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
84
+ self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
85
+ self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
86
+ self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
87
+ self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
88
+ self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
89
+ self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
90
+ self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
91
+ self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
92
+ self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
93
+ self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
94
+ self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
95
+ self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
96
+ self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
97
+ self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
98
+ self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
99
+ self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
100
+ self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
101
+ self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
102
+ self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
103
+ self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
104
+ self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
105
+ self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
106
+ self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
107
+ self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
108
+ self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
109
+ self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
110
+ self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
111
+ self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
112
+ self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
113
+ self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
114
+ self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
115
+ self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
116
+ self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
117
+ self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
118
+ self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
119
+ self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
120
+ self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
121
+ self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
122
+ self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
123
+ self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
124
+ self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
125
+ self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
126
+ self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
127
+ self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
128
+ self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
129
+ self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
130
+ self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
131
+ self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
132
+ self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
133
+ self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
134
+ self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
135
+ self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
136
+ self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
137
+ self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
138
+ self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
139
+ self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
140
+ self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
141
+ self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
142
+ self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
143
+ self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
144
+ self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
145
+ self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
146
+ self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
147
+ self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
148
+ self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
149
+ self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
150
+ self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
151
+ self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
152
+ self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
153
+ self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
154
+ self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
155
+ self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
156
+ self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
157
+ self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
158
+ self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
159
+ self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
160
+ self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
161
+ self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
162
+ self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
163
+ self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
164
+ self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
165
+ self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
166
+ self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
167
+ self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
168
+ self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
169
+ self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
170
+ self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
171
+ self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
172
+ self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
173
+ self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
174
+ self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
175
+ self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
176
+ self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
177
+ self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
178
+ self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
179
+ self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
180
+ self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
181
+ self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
182
+ self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
183
+ self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
184
+ self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
185
+ self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
186
+ self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
187
+ self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
188
+ self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
189
+ self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
190
+ self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
191
+ self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
192
+ self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
193
+ self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
194
+ self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
195
+ self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
196
+
197
+ def forward(self, *inputs):
198
+ t_358, = inputs
199
+ t_359 = t_358.permute(*[0, 3, 1, 2])
200
+ t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
201
+ t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
202
+ t_361 = F.relu(t_360)
203
+ t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
204
+ t_362 = self.n_MaxPool_0(t_361)
205
+ t_363 = self.n_Conv_1(t_362)
206
+ t_364 = self.n_Conv_2(t_362)
207
+ t_365 = F.relu(t_364)
208
+ t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
209
+ t_366 = self.n_Conv_3(t_365_padded)
210
+ t_367 = F.relu(t_366)
211
+ t_368 = self.n_Conv_4(t_367)
212
+ t_369 = torch.add(t_368, t_363)
213
+ t_370 = F.relu(t_369)
214
+ t_371 = self.n_Conv_5(t_370)
215
+ t_372 = F.relu(t_371)
216
+ t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
217
+ t_373 = self.n_Conv_6(t_372_padded)
218
+ t_374 = F.relu(t_373)
219
+ t_375 = self.n_Conv_7(t_374)
220
+ t_376 = torch.add(t_375, t_370)
221
+ t_377 = F.relu(t_376)
222
+ t_378 = self.n_Conv_8(t_377)
223
+ t_379 = F.relu(t_378)
224
+ t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
225
+ t_380 = self.n_Conv_9(t_379_padded)
226
+ t_381 = F.relu(t_380)
227
+ t_382 = self.n_Conv_10(t_381)
228
+ t_383 = torch.add(t_382, t_377)
229
+ t_384 = F.relu(t_383)
230
+ t_385 = self.n_Conv_11(t_384)
231
+ t_386 = self.n_Conv_12(t_384)
232
+ t_387 = F.relu(t_386)
233
+ t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
234
+ t_388 = self.n_Conv_13(t_387_padded)
235
+ t_389 = F.relu(t_388)
236
+ t_390 = self.n_Conv_14(t_389)
237
+ t_391 = torch.add(t_390, t_385)
238
+ t_392 = F.relu(t_391)
239
+ t_393 = self.n_Conv_15(t_392)
240
+ t_394 = F.relu(t_393)
241
+ t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
242
+ t_395 = self.n_Conv_16(t_394_padded)
243
+ t_396 = F.relu(t_395)
244
+ t_397 = self.n_Conv_17(t_396)
245
+ t_398 = torch.add(t_397, t_392)
246
+ t_399 = F.relu(t_398)
247
+ t_400 = self.n_Conv_18(t_399)
248
+ t_401 = F.relu(t_400)
249
+ t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
250
+ t_402 = self.n_Conv_19(t_401_padded)
251
+ t_403 = F.relu(t_402)
252
+ t_404 = self.n_Conv_20(t_403)
253
+ t_405 = torch.add(t_404, t_399)
254
+ t_406 = F.relu(t_405)
255
+ t_407 = self.n_Conv_21(t_406)
256
+ t_408 = F.relu(t_407)
257
+ t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
258
+ t_409 = self.n_Conv_22(t_408_padded)
259
+ t_410 = F.relu(t_409)
260
+ t_411 = self.n_Conv_23(t_410)
261
+ t_412 = torch.add(t_411, t_406)
262
+ t_413 = F.relu(t_412)
263
+ t_414 = self.n_Conv_24(t_413)
264
+ t_415 = F.relu(t_414)
265
+ t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
266
+ t_416 = self.n_Conv_25(t_415_padded)
267
+ t_417 = F.relu(t_416)
268
+ t_418 = self.n_Conv_26(t_417)
269
+ t_419 = torch.add(t_418, t_413)
270
+ t_420 = F.relu(t_419)
271
+ t_421 = self.n_Conv_27(t_420)
272
+ t_422 = F.relu(t_421)
273
+ t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
274
+ t_423 = self.n_Conv_28(t_422_padded)
275
+ t_424 = F.relu(t_423)
276
+ t_425 = self.n_Conv_29(t_424)
277
+ t_426 = torch.add(t_425, t_420)
278
+ t_427 = F.relu(t_426)
279
+ t_428 = self.n_Conv_30(t_427)
280
+ t_429 = F.relu(t_428)
281
+ t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
282
+ t_430 = self.n_Conv_31(t_429_padded)
283
+ t_431 = F.relu(t_430)
284
+ t_432 = self.n_Conv_32(t_431)
285
+ t_433 = torch.add(t_432, t_427)
286
+ t_434 = F.relu(t_433)
287
+ t_435 = self.n_Conv_33(t_434)
288
+ t_436 = F.relu(t_435)
289
+ t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
290
+ t_437 = self.n_Conv_34(t_436_padded)
291
+ t_438 = F.relu(t_437)
292
+ t_439 = self.n_Conv_35(t_438)
293
+ t_440 = torch.add(t_439, t_434)
294
+ t_441 = F.relu(t_440)
295
+ t_442 = self.n_Conv_36(t_441)
296
+ t_443 = self.n_Conv_37(t_441)
297
+ t_444 = F.relu(t_443)
298
+ t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
299
+ t_445 = self.n_Conv_38(t_444_padded)
300
+ t_446 = F.relu(t_445)
301
+ t_447 = self.n_Conv_39(t_446)
302
+ t_448 = torch.add(t_447, t_442)
303
+ t_449 = F.relu(t_448)
304
+ t_450 = self.n_Conv_40(t_449)
305
+ t_451 = F.relu(t_450)
306
+ t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
307
+ t_452 = self.n_Conv_41(t_451_padded)
308
+ t_453 = F.relu(t_452)
309
+ t_454 = self.n_Conv_42(t_453)
310
+ t_455 = torch.add(t_454, t_449)
311
+ t_456 = F.relu(t_455)
312
+ t_457 = self.n_Conv_43(t_456)
313
+ t_458 = F.relu(t_457)
314
+ t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
315
+ t_459 = self.n_Conv_44(t_458_padded)
316
+ t_460 = F.relu(t_459)
317
+ t_461 = self.n_Conv_45(t_460)
318
+ t_462 = torch.add(t_461, t_456)
319
+ t_463 = F.relu(t_462)
320
+ t_464 = self.n_Conv_46(t_463)
321
+ t_465 = F.relu(t_464)
322
+ t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
323
+ t_466 = self.n_Conv_47(t_465_padded)
324
+ t_467 = F.relu(t_466)
325
+ t_468 = self.n_Conv_48(t_467)
326
+ t_469 = torch.add(t_468, t_463)
327
+ t_470 = F.relu(t_469)
328
+ t_471 = self.n_Conv_49(t_470)
329
+ t_472 = F.relu(t_471)
330
+ t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
331
+ t_473 = self.n_Conv_50(t_472_padded)
332
+ t_474 = F.relu(t_473)
333
+ t_475 = self.n_Conv_51(t_474)
334
+ t_476 = torch.add(t_475, t_470)
335
+ t_477 = F.relu(t_476)
336
+ t_478 = self.n_Conv_52(t_477)
337
+ t_479 = F.relu(t_478)
338
+ t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
339
+ t_480 = self.n_Conv_53(t_479_padded)
340
+ t_481 = F.relu(t_480)
341
+ t_482 = self.n_Conv_54(t_481)
342
+ t_483 = torch.add(t_482, t_477)
343
+ t_484 = F.relu(t_483)
344
+ t_485 = self.n_Conv_55(t_484)
345
+ t_486 = F.relu(t_485)
346
+ t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
347
+ t_487 = self.n_Conv_56(t_486_padded)
348
+ t_488 = F.relu(t_487)
349
+ t_489 = self.n_Conv_57(t_488)
350
+ t_490 = torch.add(t_489, t_484)
351
+ t_491 = F.relu(t_490)
352
+ t_492 = self.n_Conv_58(t_491)
353
+ t_493 = F.relu(t_492)
354
+ t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
355
+ t_494 = self.n_Conv_59(t_493_padded)
356
+ t_495 = F.relu(t_494)
357
+ t_496 = self.n_Conv_60(t_495)
358
+ t_497 = torch.add(t_496, t_491)
359
+ t_498 = F.relu(t_497)
360
+ t_499 = self.n_Conv_61(t_498)
361
+ t_500 = F.relu(t_499)
362
+ t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
363
+ t_501 = self.n_Conv_62(t_500_padded)
364
+ t_502 = F.relu(t_501)
365
+ t_503 = self.n_Conv_63(t_502)
366
+ t_504 = torch.add(t_503, t_498)
367
+ t_505 = F.relu(t_504)
368
+ t_506 = self.n_Conv_64(t_505)
369
+ t_507 = F.relu(t_506)
370
+ t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
371
+ t_508 = self.n_Conv_65(t_507_padded)
372
+ t_509 = F.relu(t_508)
373
+ t_510 = self.n_Conv_66(t_509)
374
+ t_511 = torch.add(t_510, t_505)
375
+ t_512 = F.relu(t_511)
376
+ t_513 = self.n_Conv_67(t_512)
377
+ t_514 = F.relu(t_513)
378
+ t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
379
+ t_515 = self.n_Conv_68(t_514_padded)
380
+ t_516 = F.relu(t_515)
381
+ t_517 = self.n_Conv_69(t_516)
382
+ t_518 = torch.add(t_517, t_512)
383
+ t_519 = F.relu(t_518)
384
+ t_520 = self.n_Conv_70(t_519)
385
+ t_521 = F.relu(t_520)
386
+ t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
387
+ t_522 = self.n_Conv_71(t_521_padded)
388
+ t_523 = F.relu(t_522)
389
+ t_524 = self.n_Conv_72(t_523)
390
+ t_525 = torch.add(t_524, t_519)
391
+ t_526 = F.relu(t_525)
392
+ t_527 = self.n_Conv_73(t_526)
393
+ t_528 = F.relu(t_527)
394
+ t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
395
+ t_529 = self.n_Conv_74(t_528_padded)
396
+ t_530 = F.relu(t_529)
397
+ t_531 = self.n_Conv_75(t_530)
398
+ t_532 = torch.add(t_531, t_526)
399
+ t_533 = F.relu(t_532)
400
+ t_534 = self.n_Conv_76(t_533)
401
+ t_535 = F.relu(t_534)
402
+ t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
403
+ t_536 = self.n_Conv_77(t_535_padded)
404
+ t_537 = F.relu(t_536)
405
+ t_538 = self.n_Conv_78(t_537)
406
+ t_539 = torch.add(t_538, t_533)
407
+ t_540 = F.relu(t_539)
408
+ t_541 = self.n_Conv_79(t_540)
409
+ t_542 = F.relu(t_541)
410
+ t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
411
+ t_543 = self.n_Conv_80(t_542_padded)
412
+ t_544 = F.relu(t_543)
413
+ t_545 = self.n_Conv_81(t_544)
414
+ t_546 = torch.add(t_545, t_540)
415
+ t_547 = F.relu(t_546)
416
+ t_548 = self.n_Conv_82(t_547)
417
+ t_549 = F.relu(t_548)
418
+ t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
419
+ t_550 = self.n_Conv_83(t_549_padded)
420
+ t_551 = F.relu(t_550)
421
+ t_552 = self.n_Conv_84(t_551)
422
+ t_553 = torch.add(t_552, t_547)
423
+ t_554 = F.relu(t_553)
424
+ t_555 = self.n_Conv_85(t_554)
425
+ t_556 = F.relu(t_555)
426
+ t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
427
+ t_557 = self.n_Conv_86(t_556_padded)
428
+ t_558 = F.relu(t_557)
429
+ t_559 = self.n_Conv_87(t_558)
430
+ t_560 = torch.add(t_559, t_554)
431
+ t_561 = F.relu(t_560)
432
+ t_562 = self.n_Conv_88(t_561)
433
+ t_563 = F.relu(t_562)
434
+ t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
435
+ t_564 = self.n_Conv_89(t_563_padded)
436
+ t_565 = F.relu(t_564)
437
+ t_566 = self.n_Conv_90(t_565)
438
+ t_567 = torch.add(t_566, t_561)
439
+ t_568 = F.relu(t_567)
440
+ t_569 = self.n_Conv_91(t_568)
441
+ t_570 = F.relu(t_569)
442
+ t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
443
+ t_571 = self.n_Conv_92(t_570_padded)
444
+ t_572 = F.relu(t_571)
445
+ t_573 = self.n_Conv_93(t_572)
446
+ t_574 = torch.add(t_573, t_568)
447
+ t_575 = F.relu(t_574)
448
+ t_576 = self.n_Conv_94(t_575)
449
+ t_577 = F.relu(t_576)
450
+ t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
451
+ t_578 = self.n_Conv_95(t_577_padded)
452
+ t_579 = F.relu(t_578)
453
+ t_580 = self.n_Conv_96(t_579)
454
+ t_581 = torch.add(t_580, t_575)
455
+ t_582 = F.relu(t_581)
456
+ t_583 = self.n_Conv_97(t_582)
457
+ t_584 = F.relu(t_583)
458
+ t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
459
+ t_585 = self.n_Conv_98(t_584_padded)
460
+ t_586 = F.relu(t_585)
461
+ t_587 = self.n_Conv_99(t_586)
462
+ t_588 = self.n_Conv_100(t_582)
463
+ t_589 = torch.add(t_587, t_588)
464
+ t_590 = F.relu(t_589)
465
+ t_591 = self.n_Conv_101(t_590)
466
+ t_592 = F.relu(t_591)
467
+ t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
468
+ t_593 = self.n_Conv_102(t_592_padded)
469
+ t_594 = F.relu(t_593)
470
+ t_595 = self.n_Conv_103(t_594)
471
+ t_596 = torch.add(t_595, t_590)
472
+ t_597 = F.relu(t_596)
473
+ t_598 = self.n_Conv_104(t_597)
474
+ t_599 = F.relu(t_598)
475
+ t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
476
+ t_600 = self.n_Conv_105(t_599_padded)
477
+ t_601 = F.relu(t_600)
478
+ t_602 = self.n_Conv_106(t_601)
479
+ t_603 = torch.add(t_602, t_597)
480
+ t_604 = F.relu(t_603)
481
+ t_605 = self.n_Conv_107(t_604)
482
+ t_606 = F.relu(t_605)
483
+ t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
484
+ t_607 = self.n_Conv_108(t_606_padded)
485
+ t_608 = F.relu(t_607)
486
+ t_609 = self.n_Conv_109(t_608)
487
+ t_610 = torch.add(t_609, t_604)
488
+ t_611 = F.relu(t_610)
489
+ t_612 = self.n_Conv_110(t_611)
490
+ t_613 = F.relu(t_612)
491
+ t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
492
+ t_614 = self.n_Conv_111(t_613_padded)
493
+ t_615 = F.relu(t_614)
494
+ t_616 = self.n_Conv_112(t_615)
495
+ t_617 = torch.add(t_616, t_611)
496
+ t_618 = F.relu(t_617)
497
+ t_619 = self.n_Conv_113(t_618)
498
+ t_620 = F.relu(t_619)
499
+ t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
500
+ t_621 = self.n_Conv_114(t_620_padded)
501
+ t_622 = F.relu(t_621)
502
+ t_623 = self.n_Conv_115(t_622)
503
+ t_624 = torch.add(t_623, t_618)
504
+ t_625 = F.relu(t_624)
505
+ t_626 = self.n_Conv_116(t_625)
506
+ t_627 = F.relu(t_626)
507
+ t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
508
+ t_628 = self.n_Conv_117(t_627_padded)
509
+ t_629 = F.relu(t_628)
510
+ t_630 = self.n_Conv_118(t_629)
511
+ t_631 = torch.add(t_630, t_625)
512
+ t_632 = F.relu(t_631)
513
+ t_633 = self.n_Conv_119(t_632)
514
+ t_634 = F.relu(t_633)
515
+ t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
516
+ t_635 = self.n_Conv_120(t_634_padded)
517
+ t_636 = F.relu(t_635)
518
+ t_637 = self.n_Conv_121(t_636)
519
+ t_638 = torch.add(t_637, t_632)
520
+ t_639 = F.relu(t_638)
521
+ t_640 = self.n_Conv_122(t_639)
522
+ t_641 = F.relu(t_640)
523
+ t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
524
+ t_642 = self.n_Conv_123(t_641_padded)
525
+ t_643 = F.relu(t_642)
526
+ t_644 = self.n_Conv_124(t_643)
527
+ t_645 = torch.add(t_644, t_639)
528
+ t_646 = F.relu(t_645)
529
+ t_647 = self.n_Conv_125(t_646)
530
+ t_648 = F.relu(t_647)
531
+ t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
532
+ t_649 = self.n_Conv_126(t_648_padded)
533
+ t_650 = F.relu(t_649)
534
+ t_651 = self.n_Conv_127(t_650)
535
+ t_652 = torch.add(t_651, t_646)
536
+ t_653 = F.relu(t_652)
537
+ t_654 = self.n_Conv_128(t_653)
538
+ t_655 = F.relu(t_654)
539
+ t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
540
+ t_656 = self.n_Conv_129(t_655_padded)
541
+ t_657 = F.relu(t_656)
542
+ t_658 = self.n_Conv_130(t_657)
543
+ t_659 = torch.add(t_658, t_653)
544
+ t_660 = F.relu(t_659)
545
+ t_661 = self.n_Conv_131(t_660)
546
+ t_662 = F.relu(t_661)
547
+ t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
548
+ t_663 = self.n_Conv_132(t_662_padded)
549
+ t_664 = F.relu(t_663)
550
+ t_665 = self.n_Conv_133(t_664)
551
+ t_666 = torch.add(t_665, t_660)
552
+ t_667 = F.relu(t_666)
553
+ t_668 = self.n_Conv_134(t_667)
554
+ t_669 = F.relu(t_668)
555
+ t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
556
+ t_670 = self.n_Conv_135(t_669_padded)
557
+ t_671 = F.relu(t_670)
558
+ t_672 = self.n_Conv_136(t_671)
559
+ t_673 = torch.add(t_672, t_667)
560
+ t_674 = F.relu(t_673)
561
+ t_675 = self.n_Conv_137(t_674)
562
+ t_676 = F.relu(t_675)
563
+ t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
564
+ t_677 = self.n_Conv_138(t_676_padded)
565
+ t_678 = F.relu(t_677)
566
+ t_679 = self.n_Conv_139(t_678)
567
+ t_680 = torch.add(t_679, t_674)
568
+ t_681 = F.relu(t_680)
569
+ t_682 = self.n_Conv_140(t_681)
570
+ t_683 = F.relu(t_682)
571
+ t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
572
+ t_684 = self.n_Conv_141(t_683_padded)
573
+ t_685 = F.relu(t_684)
574
+ t_686 = self.n_Conv_142(t_685)
575
+ t_687 = torch.add(t_686, t_681)
576
+ t_688 = F.relu(t_687)
577
+ t_689 = self.n_Conv_143(t_688)
578
+ t_690 = F.relu(t_689)
579
+ t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
580
+ t_691 = self.n_Conv_144(t_690_padded)
581
+ t_692 = F.relu(t_691)
582
+ t_693 = self.n_Conv_145(t_692)
583
+ t_694 = torch.add(t_693, t_688)
584
+ t_695 = F.relu(t_694)
585
+ t_696 = self.n_Conv_146(t_695)
586
+ t_697 = F.relu(t_696)
587
+ t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
588
+ t_698 = self.n_Conv_147(t_697_padded)
589
+ t_699 = F.relu(t_698)
590
+ t_700 = self.n_Conv_148(t_699)
591
+ t_701 = torch.add(t_700, t_695)
592
+ t_702 = F.relu(t_701)
593
+ t_703 = self.n_Conv_149(t_702)
594
+ t_704 = F.relu(t_703)
595
+ t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
596
+ t_705 = self.n_Conv_150(t_704_padded)
597
+ t_706 = F.relu(t_705)
598
+ t_707 = self.n_Conv_151(t_706)
599
+ t_708 = torch.add(t_707, t_702)
600
+ t_709 = F.relu(t_708)
601
+ t_710 = self.n_Conv_152(t_709)
602
+ t_711 = F.relu(t_710)
603
+ t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
604
+ t_712 = self.n_Conv_153(t_711_padded)
605
+ t_713 = F.relu(t_712)
606
+ t_714 = self.n_Conv_154(t_713)
607
+ t_715 = torch.add(t_714, t_709)
608
+ t_716 = F.relu(t_715)
609
+ t_717 = self.n_Conv_155(t_716)
610
+ t_718 = F.relu(t_717)
611
+ t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
612
+ t_719 = self.n_Conv_156(t_718_padded)
613
+ t_720 = F.relu(t_719)
614
+ t_721 = self.n_Conv_157(t_720)
615
+ t_722 = torch.add(t_721, t_716)
616
+ t_723 = F.relu(t_722)
617
+ t_724 = self.n_Conv_158(t_723)
618
+ t_725 = self.n_Conv_159(t_723)
619
+ t_726 = F.relu(t_725)
620
+ t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
621
+ t_727 = self.n_Conv_160(t_726_padded)
622
+ t_728 = F.relu(t_727)
623
+ t_729 = self.n_Conv_161(t_728)
624
+ t_730 = torch.add(t_729, t_724)
625
+ t_731 = F.relu(t_730)
626
+ t_732 = self.n_Conv_162(t_731)
627
+ t_733 = F.relu(t_732)
628
+ t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
629
+ t_734 = self.n_Conv_163(t_733_padded)
630
+ t_735 = F.relu(t_734)
631
+ t_736 = self.n_Conv_164(t_735)
632
+ t_737 = torch.add(t_736, t_731)
633
+ t_738 = F.relu(t_737)
634
+ t_739 = self.n_Conv_165(t_738)
635
+ t_740 = F.relu(t_739)
636
+ t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
637
+ t_741 = self.n_Conv_166(t_740_padded)
638
+ t_742 = F.relu(t_741)
639
+ t_743 = self.n_Conv_167(t_742)
640
+ t_744 = torch.add(t_743, t_738)
641
+ t_745 = F.relu(t_744)
642
+ t_746 = self.n_Conv_168(t_745)
643
+ t_747 = self.n_Conv_169(t_745)
644
+ t_748 = F.relu(t_747)
645
+ t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
646
+ t_749 = self.n_Conv_170(t_748_padded)
647
+ t_750 = F.relu(t_749)
648
+ t_751 = self.n_Conv_171(t_750)
649
+ t_752 = torch.add(t_751, t_746)
650
+ t_753 = F.relu(t_752)
651
+ t_754 = self.n_Conv_172(t_753)
652
+ t_755 = F.relu(t_754)
653
+ t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
654
+ t_756 = self.n_Conv_173(t_755_padded)
655
+ t_757 = F.relu(t_756)
656
+ t_758 = self.n_Conv_174(t_757)
657
+ t_759 = torch.add(t_758, t_753)
658
+ t_760 = F.relu(t_759)
659
+ t_761 = self.n_Conv_175(t_760)
660
+ t_762 = F.relu(t_761)
661
+ t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
662
+ t_763 = self.n_Conv_176(t_762_padded)
663
+ t_764 = F.relu(t_763)
664
+ t_765 = self.n_Conv_177(t_764)
665
+ t_766 = torch.add(t_765, t_760)
666
+ t_767 = F.relu(t_766)
667
+ t_768 = self.n_Conv_178(t_767)
668
+ t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
669
+ t_770 = torch.squeeze(t_769, 3)
670
+ t_770 = torch.squeeze(t_770, 2)
671
+ t_771 = torch.sigmoid(t_770)
672
+ return t_771
673
+
674
+ def load_state_dict(self, state_dict, **kwargs):
675
+ self.tags = state_dict.get('tags', [])
676
+
677
+ super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
678
+
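As a hedged aside (not part of the commit), the overridden `load_state_dict` above is what lets one `.pt` checkpoint carry both the weights and the tag list; a direct-load sketch, with a hypothetical checkpoint path:

# Hypothetical sketch: loading the checkpoint directly, outside the DeepDanbooru wrapper.
import torch
from modules.deepbooru_model import DeepDanbooruModel

model = DeepDanbooruModel()
state = torch.load("model-resnet_custom_v3.pt", map_location="cpu")  # checkpoint also stores 'tags'
model.load_state_dict(state)   # override copies 'tags' into model.tags and filters it from the weight dict
model.eval()
print(len(model.tags))         # one tag per output channel of the final 1x1 convolution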
sd/stable-diffusion-webui/modules/devices.py ADDED
@@ -0,0 +1,152 @@
import sys
import contextlib
import torch
from modules import errors

if sys.platform == "darwin":
    from modules import mac_specific


def has_mps() -> bool:
    if sys.platform != "darwin":
        return False
    else:
        return mac_specific.has_mps

def extract_device_id(args, name):
    for x in range(len(args)):
        if name in args[x]:
            return args[x + 1]

    return None


def get_cuda_device_string():
    from modules import shared

    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"

    return "cuda"


def get_optimal_device_name():
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return "mps"

    return "cpu"


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    from modules import shared

    if task in shared.cmd_opts.use_cpu:
        return cpu

    return get_optimal_device()


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


def enable_tf32():
    if torch.cuda.is_available():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True



errors.run(enable_tf32, "Enabling TF32")

cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
    return input.float() if unet_needs_upcast else input


def randn(seed, shape):
    torch.manual_seed(seed)
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


def autocast(disable=False):
    from modules import shared

    if disable:
        return contextlib.nullcontext()

    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")


def without_autocast(disable=False):
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


class NansException(Exception):
    pass


def test_for_nans(x, where):
    from modules import shared

    if shared.cmd_opts.disable_nan_check:
        return

    if not torch.all(torch.isnan(x)).item():
        return

    if where == "unet":
        message = "A tensor with all NaNs was produced in Unet."

        if not shared.cmd_opts.no_half:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

    elif where == "vae":
        message = "A tensor with all NaNs was produced in VAE."

        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with all NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)
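A short hedged sketch of how these helpers compose (not part of the commit; in the real webui `devices.device` and the command-line options are populated at startup, so the assignments below are assumptions for illustration): seeded noise generation routed through the CPU on MPS, autocast for half-precision work, the all-NaN check, and garbage collection.

# Hypothetical usage sketch -- assumes the webui's shared options/cmd_opts are initialized.
import torch
from modules import devices

devices.device = devices.get_optimal_device()              # normally assigned by the webui at startup
noise = devices.randn(seed=1234, shape=(1, 4, 64, 64))     # reproducible latent-shaped noise

with devices.autocast():
    result = noise * 2.0                                   # placeholder for a real model forward pass

devices.test_for_nans(result, "unet")   # raises NansException only if every element is NaN
devices.torch_gc()                      # frees cached CUDA memory when a GPU is available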
sd/stable-diffusion-webui/modules/errors.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import traceback
3
+
4
+
5
+ def print_error_explanation(message):
6
+ lines = message.strip().split("\n")
7
+ max_len = max([len(x) for x in lines])
8
+
9
+ print('=' * max_len, file=sys.stderr)
10
+ for line in lines:
11
+ print(line, file=sys.stderr)
12
+ print('=' * max_len, file=sys.stderr)
13
+
14
+
15
+ def display(e: Exception, task):
16
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
17
+ print(traceback.format_exc(), file=sys.stderr)
18
+
19
+ message = str(e)
20
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
21
+ print_error_explanation("""
22
+ The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
23
+ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
24
+ """)
25
+
26
+
27
+ already_displayed = {}
28
+
29
+
30
+ def display_once(e: Exception, task):
31
+ if task in already_displayed:
32
+ return
33
+
34
+ display(e, task)
35
+
36
+ already_displayed[task] = 1
37
+
38
+
39
+ def run(code, task):
40
+ try:
41
+ code()
42
+ except Exception as e:
43
+ display(e, task)
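A short usage sketch of the helpers above (devices.py already uses this pattern via errors.run(enable_tf32, "Enabling TF32")); the risky_setup function is a made-up stand-in:

from modules import errors

def risky_setup():
    raise RuntimeError("boom")

# prints "enabling example feature: RuntimeError" plus a traceback to stderr,
# but never lets the exception propagate to the caller
errors.run(risky_setup, "enabling example feature")

try:
    risky_setup()
except Exception as e:
    # display_once() reports a given task only the first time it fails
    errors.display_once(e, "example task")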
sd/stable-diffusion-webui/modules/esrgan_model.py ADDED
@@ -0,0 +1,233 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from basicsr.utils.download_util import load_file_from_url
7
+
8
+ import modules.esrgan_model_arch as arch
9
+ from modules import shared, modelloader, images, devices
10
+ from modules.upscaler import Upscaler, UpscalerData
11
+ from modules.shared import opts
12
+
13
+
14
+
15
+ def mod2normal(state_dict):
16
+ # this code is copied from https://github.com/victorca25/iNNfer
17
+ if 'conv_first.weight' in state_dict:
18
+ crt_net = {}
19
+ items = []
20
+ for k, v in state_dict.items():
21
+ items.append(k)
22
+
23
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
24
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
25
+
26
+ for k in items.copy():
27
+ if 'RDB' in k:
28
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
29
+ if '.weight' in k:
30
+ ori_k = ori_k.replace('.weight', '.0.weight')
31
+ elif '.bias' in k:
32
+ ori_k = ori_k.replace('.bias', '.0.bias')
33
+ crt_net[ori_k] = state_dict[k]
34
+ items.remove(k)
35
+
36
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
37
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
38
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
39
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
40
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
41
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
42
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
43
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
44
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
45
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
46
+ state_dict = crt_net
47
+ return state_dict
48
+
49
+
50
+ def resrgan2normal(state_dict, nb=23):
51
+ # this code is copied from https://github.com/victorca25/iNNfer
52
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
53
+ re8x = 0
54
+ crt_net = {}
55
+ items = []
56
+ for k, v in state_dict.items():
57
+ items.append(k)
58
+
59
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
60
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
61
+
62
+ for k in items.copy():
63
+ if "rdb" in k:
64
+ ori_k = k.replace('body.', 'model.1.sub.')
65
+ ori_k = ori_k.replace('.rdb', '.RDB')
66
+ if '.weight' in k:
67
+ ori_k = ori_k.replace('.weight', '.0.weight')
68
+ elif '.bias' in k:
69
+ ori_k = ori_k.replace('.bias', '.0.bias')
70
+ crt_net[ori_k] = state_dict[k]
71
+ items.remove(k)
72
+
73
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
74
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
75
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
76
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
77
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
78
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
79
+
80
+ if 'conv_up3.weight' in state_dict:
81
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
82
+ re8x = 3
83
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
84
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
85
+
86
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
87
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
88
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
89
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
90
+
91
+ state_dict = crt_net
92
+ return state_dict
93
+
94
+
95
+ def infer_params(state_dict):
96
+ # this code is copied from https://github.com/victorca25/iNNfer
97
+ scale2x = 0
98
+ scalemin = 6
99
+ n_uplayer = 0
100
+ plus = False
101
+
102
+ for block in list(state_dict):
103
+ parts = block.split(".")
104
+ n_parts = len(parts)
105
+ if n_parts == 5 and parts[2] == "sub":
106
+ nb = int(parts[3])
107
+ elif n_parts == 3:
108
+ part_num = int(parts[1])
109
+ if (part_num > scalemin
110
+ and parts[0] == "model"
111
+ and parts[2] == "weight"):
112
+ scale2x += 1
113
+ if part_num > n_uplayer:
114
+ n_uplayer = part_num
115
+ out_nc = state_dict[block].shape[0]
116
+ if not plus and "conv1x1" in block:
117
+ plus = True
118
+
119
+ nf = state_dict["model.0.weight"].shape[0]
120
+ in_nc = state_dict["model.0.weight"].shape[1]
121
+ out_nc = out_nc
122
+ scale = 2 ** scale2x
123
+
124
+ return in_nc, out_nc, nf, nb, plus, scale
125
+
126
+
127
+ class UpscalerESRGAN(Upscaler):
128
+ def __init__(self, dirname):
129
+ self.name = "ESRGAN"
130
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
131
+ self.model_name = "ESRGAN_4x"
132
+ self.scalers = []
133
+ self.user_path = dirname
134
+ super().__init__()
135
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
136
+ scalers = []
137
+ if len(model_paths) == 0:
138
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
139
+ self.scalers.append(scaler_data)
140
+ for file in model_paths:
141
+ if "http" in file:
142
+ name = self.model_name
143
+ else:
144
+ name = modelloader.friendly_name(file)
145
+
146
+ scaler_data = UpscalerData(name, file, self, 4)
147
+ self.scalers.append(scaler_data)
148
+
149
+ def do_upscale(self, img, selected_model):
150
+ model = self.load_model(selected_model)
151
+ if model is None:
152
+ return img
153
+ model.to(devices.device_esrgan)
154
+ img = esrgan_upscale(model, img)
155
+ return img
156
+
157
+ def load_model(self, path: str):
158
+ if "http" in path:
159
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
160
+ file_name="%s.pth" % self.model_name,
161
+ progress=True)
162
+ else:
163
+ filename = path
164
+ if filename is None or not os.path.exists(filename):
165
+ print("Unable to load %s from %s" % (self.model_path, filename))
166
+ return None
167
+
168
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
169
+
170
+ if "params_ema" in state_dict:
171
+ state_dict = state_dict["params_ema"]
172
+ elif "params" in state_dict:
173
+ state_dict = state_dict["params"]
174
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
175
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
176
+ model.load_state_dict(state_dict)
177
+ model.eval()
178
+ return model
179
+
180
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
181
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
182
+ state_dict = resrgan2normal(state_dict, nb)
183
+ elif "conv_first.weight" in state_dict:
184
+ state_dict = mod2normal(state_dict)
185
+ elif "model.0.weight" not in state_dict:
186
+ raise Exception("The file is not a recognized ESRGAN model.")
187
+
188
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
189
+
190
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
191
+ model.load_state_dict(state_dict)
192
+ model.eval()
193
+
194
+ return model
195
+
196
+
197
+ def upscale_without_tiling(model, img):
198
+ img = np.array(img)
199
+ img = img[:, :, ::-1]
200
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
201
+ img = torch.from_numpy(img).float()
202
+ img = img.unsqueeze(0).to(devices.device_esrgan)
203
+ with torch.no_grad():
204
+ output = model(img)
205
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
206
+ output = 255. * np.moveaxis(output, 0, 2)
207
+ output = output.astype(np.uint8)
208
+ output = output[:, :, ::-1]
209
+ return Image.fromarray(output, 'RGB')
210
+
211
+
212
+ def esrgan_upscale(model, img):
213
+ if opts.ESRGAN_tile == 0:
214
+ return upscale_without_tiling(model, img)
215
+
216
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
217
+ newtiles = []
218
+ scale_factor = 1
219
+
220
+ for y, h, row in grid.tiles:
221
+ newrow = []
222
+ for tiledata in row:
223
+ x, w, tile = tiledata
224
+
225
+ output = upscale_without_tiling(model, tile)
226
+ scale_factor = output.width // tile.width
227
+
228
+ newrow.append([x * scale_factor, w * scale_factor, output])
229
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
230
+
231
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
232
+ output = images.combine_grid(newgrid)
233
+ return output
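For reference, infer_params() derives everything from the key names and shapes of the state dict; a toy, hand-built dict (not a real checkpoint) that exercises it, assuming modules.esrgan_model is importable:

import torch
from modules.esrgan_model import infer_params

toy = {
    "model.0.weight": torch.zeros(64, 3, 3, 3),          # nf=64, in_nc=3
    "model.1.sub.23.weight": torch.zeros(64, 64, 3, 3),  # "model.1.sub.<nb>" -> nb=23
    "model.3.weight": torch.zeros(64, 64, 3, 3),
    "model.6.weight": torch.zeros(64, 64, 3, 3),
    "model.8.weight": torch.zeros(64, 64, 3, 3),          # index > 6 -> one 2x stage
    "model.10.weight": torch.zeros(3, 64, 3, 3),          # index > 6 -> another 2x stage, out_nc=3
}

print(infer_params(toy))  # (3, 3, 64, 23, False, 4): in_nc, out_nc, nf, nb, plus, scale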
sd/stable-diffusion-webui/modules/esrgan_model_arch.py ADDED
@@ -0,0 +1,464 @@
1
+ # this file is adapted from https://github.com/victorca25/iNNfer
2
+
3
+ from collections import OrderedDict
4
+ import math
5
+ import functools
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ ####################
12
+ # RRDBNet Generator
13
+ ####################
14
+
15
+ class RRDBNet(nn.Module):
16
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
17
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
18
+ finalact=None, gaussian_noise=False, plus=False):
19
+ super(RRDBNet, self).__init__()
20
+ n_upscale = int(math.log(upscale, 2))
21
+ if upscale == 3:
22
+ n_upscale = 1
23
+
24
+ self.resrgan_scale = 0
25
+ if in_nc % 16 == 0:
26
+ self.resrgan_scale = 1
27
+ elif in_nc != 4 and in_nc % 4 == 0:
28
+ self.resrgan_scale = 2
29
+
30
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
31
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
32
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
33
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
34
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
35
+
36
+ if upsample_mode == 'upconv':
37
+ upsample_block = upconv_block
38
+ elif upsample_mode == 'pixelshuffle':
39
+ upsample_block = pixelshuffle_block
40
+ else:
41
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
42
+ if upscale == 3:
43
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
44
+ else:
45
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
46
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
47
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
48
+
49
+ outact = act(finalact) if finalact else None
50
+
51
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
52
+ *upsampler, HR_conv0, HR_conv1, outact)
53
+
54
+ def forward(self, x, outm=None):
55
+ if self.resrgan_scale == 1:
56
+ feat = pixel_unshuffle(x, scale=4)
57
+ elif self.resrgan_scale == 2:
58
+ feat = pixel_unshuffle(x, scale=2)
59
+ else:
60
+ feat = x
61
+
62
+ return self.model(feat)
63
+
64
+
65
+ class RRDB(nn.Module):
66
+ """
67
+ Residual in Residual Dense Block
68
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
69
+ """
70
+
71
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
72
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
73
+ spectral_norm=False, gaussian_noise=False, plus=False):
74
+ super(RRDB, self).__init__()
75
+ # This is for backwards compatibility with existing models
76
+ if nr == 3:
77
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
78
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
79
+ gaussian_noise=gaussian_noise, plus=plus)
80
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
81
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
82
+ gaussian_noise=gaussian_noise, plus=plus)
83
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
84
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
85
+ gaussian_noise=gaussian_noise, plus=plus)
86
+ else:
87
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
88
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
89
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
90
+ self.RDBs = nn.Sequential(*RDB_list)
91
+
92
+ def forward(self, x):
93
+ if hasattr(self, 'RDB1'):
94
+ out = self.RDB1(x)
95
+ out = self.RDB2(out)
96
+ out = self.RDB3(out)
97
+ else:
98
+ out = self.RDBs(x)
99
+ return out * 0.2 + x
100
+
101
+
102
+ class ResidualDenseBlock_5C(nn.Module):
103
+ """
104
+ Residual Dense Block
105
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
106
+ Modified options that can be used:
107
+ - "Partial Convolution based Padding" arXiv:1811.11718
108
+ - "Spectral normalization" arXiv:1802.05957
109
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
110
+ Rakotonirina and A. Rasoanaivo
111
+ """
112
+
113
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
114
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
115
+ spectral_norm=False, gaussian_noise=False, plus=False):
116
+ super(ResidualDenseBlock_5C, self).__init__()
117
+
118
+ self.noise = GaussianNoise() if gaussian_noise else None
119
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
120
+
121
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
122
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
123
+ spectral_norm=spectral_norm)
124
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
125
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
126
+ spectral_norm=spectral_norm)
127
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
128
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
129
+ spectral_norm=spectral_norm)
130
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
131
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
132
+ spectral_norm=spectral_norm)
133
+ if mode == 'CNA':
134
+ last_act = None
135
+ else:
136
+ last_act = act_type
137
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
138
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
139
+ spectral_norm=spectral_norm)
140
+
141
+ def forward(self, x):
142
+ x1 = self.conv1(x)
143
+ x2 = self.conv2(torch.cat((x, x1), 1))
144
+ if self.conv1x1:
145
+ x2 = x2 + self.conv1x1(x)
146
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
147
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
148
+ if self.conv1x1:
149
+ x4 = x4 + x2
150
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
151
+ if self.noise:
152
+ return self.noise(x5.mul(0.2) + x)
153
+ else:
154
+ return x5 * 0.2 + x
155
+
156
+
157
+ ####################
158
+ # ESRGANplus
159
+ ####################
160
+
161
+ class GaussianNoise(nn.Module):
162
+ def __init__(self, sigma=0.1, is_relative_detach=False):
163
+ super().__init__()
164
+ self.sigma = sigma
165
+ self.is_relative_detach = is_relative_detach
166
+ self.noise = torch.tensor(0, dtype=torch.float)
167
+
168
+ def forward(self, x):
169
+ if self.training and self.sigma != 0:
170
+ self.noise = self.noise.to(x.device)
171
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
172
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
173
+ x = x + sampled_noise
174
+ return x
175
+
176
+ def conv1x1(in_planes, out_planes, stride=1):
177
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
178
+
179
+
180
+ ####################
181
+ # SRVGGNetCompact
182
+ ####################
183
+
184
+ class SRVGGNetCompact(nn.Module):
185
+ """A compact VGG-style network structure for super-resolution.
186
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
187
+ """
188
+
189
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
190
+ super(SRVGGNetCompact, self).__init__()
191
+ self.num_in_ch = num_in_ch
192
+ self.num_out_ch = num_out_ch
193
+ self.num_feat = num_feat
194
+ self.num_conv = num_conv
195
+ self.upscale = upscale
196
+ self.act_type = act_type
197
+
198
+ self.body = nn.ModuleList()
199
+ # the first conv
200
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
201
+ # the first activation
202
+ if act_type == 'relu':
203
+ activation = nn.ReLU(inplace=True)
204
+ elif act_type == 'prelu':
205
+ activation = nn.PReLU(num_parameters=num_feat)
206
+ elif act_type == 'leakyrelu':
207
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
208
+ self.body.append(activation)
209
+
210
+ # the body structure
211
+ for _ in range(num_conv):
212
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
213
+ # activation
214
+ if act_type == 'relu':
215
+ activation = nn.ReLU(inplace=True)
216
+ elif act_type == 'prelu':
217
+ activation = nn.PReLU(num_parameters=num_feat)
218
+ elif act_type == 'leakyrelu':
219
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
220
+ self.body.append(activation)
221
+
222
+ # the last conv
223
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
224
+ # upsample
225
+ self.upsampler = nn.PixelShuffle(upscale)
226
+
227
+ def forward(self, x):
228
+ out = x
229
+ for i in range(0, len(self.body)):
230
+ out = self.body[i](out)
231
+
232
+ out = self.upsampler(out)
233
+ # add the nearest upsampled image, so that the network learns the residual
234
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
235
+ out += base
236
+ return out
237
+
238
+
239
+ ####################
240
+ # Upsampler
241
+ ####################
242
+
243
+ class Upsample(nn.Module):
244
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
245
+ The input data is assumed to be of the form
246
+ `minibatch x channels x [optional depth] x [optional height] x width`.
247
+ """
248
+
249
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
250
+ super(Upsample, self).__init__()
251
+ if isinstance(scale_factor, tuple):
252
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
253
+ else:
254
+ self.scale_factor = float(scale_factor) if scale_factor else None
255
+ self.mode = mode
256
+ self.size = size
257
+ self.align_corners = align_corners
258
+
259
+ def forward(self, x):
260
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
261
+
262
+ def extra_repr(self):
263
+ if self.scale_factor is not None:
264
+ info = 'scale_factor=' + str(self.scale_factor)
265
+ else:
266
+ info = 'size=' + str(self.size)
267
+ info += ', mode=' + self.mode
268
+ return info
269
+
270
+
271
+ def pixel_unshuffle(x, scale):
272
+ """ Pixel unshuffle.
273
+ Args:
274
+ x (Tensor): Input feature with shape (b, c, hh, hw).
275
+ scale (int): Downsample ratio.
276
+ Returns:
277
+ Tensor: the pixel unshuffled feature.
278
+ """
279
+ b, c, hh, hw = x.size()
280
+ out_channel = c * (scale**2)
281
+ assert hh % scale == 0 and hw % scale == 0
282
+ h = hh // scale
283
+ w = hw // scale
284
+ x_view = x.view(b, c, h, scale, w, scale)
285
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
286
+
287
+
288
+ def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
289
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
290
+ """
291
+ Pixel shuffle layer
292
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
293
+ Neural Network, CVPR17)
294
+ """
295
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
296
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
297
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
298
+
299
+ n = norm(norm_type, out_nc) if norm_type else None
300
+ a = act(act_type) if act_type else None
301
+ return sequential(conv, pixel_shuffle, n, a)
302
+
303
+
304
+ def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
305
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
306
+ """ Upconv layer """
307
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
308
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
309
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
310
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
311
+ return sequential(upsample, conv)
312
+
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+ ####################
321
+ # Basic blocks
322
+ ####################
323
+
324
+
325
+ def make_layer(basic_block, num_basic_block, **kwarg):
326
+ """Make layers by stacking the same blocks.
327
+ Args:
328
+ basic_block (nn.module): nn.module class for basic block. (block)
329
+ num_basic_block (int): number of blocks. (n_layers)
330
+ Returns:
331
+ nn.Sequential: Stacked blocks in nn.Sequential.
332
+ """
333
+ layers = []
334
+ for _ in range(num_basic_block):
335
+ layers.append(basic_block(**kwarg))
336
+ return nn.Sequential(*layers)
337
+
338
+
339
+ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
340
+ """ activation helper """
341
+ act_type = act_type.lower()
342
+ if act_type == 'relu':
343
+ layer = nn.ReLU(inplace)
344
+ elif act_type in ('leakyrelu', 'lrelu'):
345
+ layer = nn.LeakyReLU(neg_slope, inplace)
346
+ elif act_type == 'prelu':
347
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
348
+ elif act_type == 'tanh': # [-1, 1] range output
349
+ layer = nn.Tanh()
350
+ elif act_type == 'sigmoid': # [0, 1] range output
351
+ layer = nn.Sigmoid()
352
+ else:
353
+ raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
354
+ return layer
355
+
356
+
357
+ class Identity(nn.Module):
358
+ def __init__(self, *kwargs):
359
+ super(Identity, self).__init__()
360
+
361
+ def forward(self, x, *kwargs):
362
+ return x
363
+
364
+
365
+ def norm(norm_type, nc):
366
+ """ Return a normalization layer """
367
+ norm_type = norm_type.lower()
368
+ if norm_type == 'batch':
369
+ layer = nn.BatchNorm2d(nc, affine=True)
370
+ elif norm_type == 'instance':
371
+ layer = nn.InstanceNorm2d(nc, affine=False)
372
+ elif norm_type == 'none':
373
+ layer = Identity()
374
+ else:
375
+ raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
376
+ return layer
377
+
378
+
379
+ def pad(pad_type, padding):
380
+ """ padding layer helper """
381
+ pad_type = pad_type.lower()
382
+ if padding == 0:
383
+ return None
384
+ if pad_type == 'reflect':
385
+ layer = nn.ReflectionPad2d(padding)
386
+ elif pad_type == 'replicate':
387
+ layer = nn.ReplicationPad2d(padding)
388
+ elif pad_type == 'zero':
389
+ layer = nn.ZeroPad2d(padding)
390
+ else:
391
+ raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
392
+ return layer
393
+
394
+
395
+ def get_valid_padding(kernel_size, dilation):
396
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
397
+ padding = (kernel_size - 1) // 2
398
+ return padding
399
+
400
+
401
+ class ShortcutBlock(nn.Module):
402
+ """ Elementwise sum the output of a submodule to its input """
403
+ def __init__(self, submodule):
404
+ super(ShortcutBlock, self).__init__()
405
+ self.sub = submodule
406
+
407
+ def forward(self, x):
408
+ output = x + self.sub(x)
409
+ return output
410
+
411
+ def __repr__(self):
412
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
413
+
414
+
415
+ def sequential(*args):
416
+ """ Flatten Sequential. It unwraps nn.Sequential. """
417
+ if len(args) == 1:
418
+ if isinstance(args[0], OrderedDict):
419
+ raise NotImplementedError('sequential does not support OrderedDict input.')
420
+ return args[0] # No sequential is needed.
421
+ modules = []
422
+ for module in args:
423
+ if isinstance(module, nn.Sequential):
424
+ for submodule in module.children():
425
+ modules.append(submodule)
426
+ elif isinstance(module, nn.Module):
427
+ modules.append(module)
428
+ return nn.Sequential(*modules)
429
+
430
+
431
+ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
432
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
433
+ spectral_norm=False):
434
+ """ Conv layer with padding, normalization, activation """
435
+ assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
436
+ padding = get_valid_padding(kernel_size, dilation)
437
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
438
+ padding = padding if pad_type == 'zero' else 0
439
+
440
+ if convtype=='PartialConv2D':
441
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
442
+ dilation=dilation, bias=bias, groups=groups)
443
+ elif convtype=='DeformConv2D':
444
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
445
+ dilation=dilation, bias=bias, groups=groups)
446
+ elif convtype=='Conv3D':
447
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
448
+ dilation=dilation, bias=bias, groups=groups)
449
+ else:
450
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
451
+ dilation=dilation, bias=bias, groups=groups)
452
+
453
+ if spectral_norm:
454
+ c = nn.utils.spectral_norm(c)
455
+
456
+ a = act(act_type) if act_type else None
457
+ if 'CNA' in mode:
458
+ n = norm(norm_type, out_nc) if norm_type else None
459
+ return sequential(p, c, n, a)
460
+ elif mode == 'NAC':
461
+ if norm_type is None and act_type is not None:
462
+ a = act(act_type, inplace=False)
463
+ n = norm(norm_type, in_nc) if norm_type else None
464
+ return sequential(n, a, p, c)
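A small sketch of the basic blocks in action: a single 'CNA' (conv -> norm -> act) unit built by conv_block() and applied to a dummy tensor; the shapes are arbitrary, and it assumes the webui's modules package is on sys.path:

import torch
from modules.esrgan_model_arch import conv_block

block = conv_block(3, 64, kernel_size=3, act_type='leakyrelu', mode='CNA')
out = block(torch.randn(1, 3, 32, 32))
# get_valid_padding() picks padding=1 for a 3x3 kernel, so spatial size is preserved
print(out.shape)  # torch.Size([1, 64, 32, 32])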
sd/stable-diffusion-webui/modules/extensions.py ADDED
@@ -0,0 +1,107 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import time
6
+ import git
7
+
8
+ from modules import paths, shared
9
+
10
+ extensions = []
11
+ extensions_dir = os.path.join(paths.data_path, "extensions")
12
+ extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
13
+
14
+ if not os.path.exists(extensions_dir):
15
+ os.makedirs(extensions_dir)
16
+
17
+ def active():
18
+ return [x for x in extensions if x.enabled]
19
+
20
+
21
+ class Extension:
22
+ def __init__(self, name, path, enabled=True, is_builtin=False):
23
+ self.name = name
24
+ self.path = path
25
+ self.enabled = enabled
26
+ self.status = ''
27
+ self.can_update = False
28
+ self.is_builtin = is_builtin
29
+ self.version = ''
30
+
31
+ repo = None
32
+ try:
33
+ if os.path.exists(os.path.join(path, ".git")):
34
+ repo = git.Repo(path)
35
+ except Exception:
36
+ print(f"Error reading github repository info from {path}:", file=sys.stderr)
37
+ print(traceback.format_exc(), file=sys.stderr)
38
+
39
+ if repo is None or repo.bare:
40
+ self.remote = None
41
+ else:
42
+ try:
43
+ self.remote = next(repo.remote().urls, None)
44
+ self.status = 'unknown'
45
+ head = repo.head.commit
46
+ ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
47
+ self.version = f'{head.hexsha[:8]} ({ts})'
48
+
49
+ except Exception:
50
+ self.remote = None
51
+
52
+ def list_files(self, subdir, extension):
53
+ from modules import scripts
54
+
55
+ dirpath = os.path.join(self.path, subdir)
56
+ if not os.path.isdir(dirpath):
57
+ return []
58
+
59
+ res = []
60
+ for filename in sorted(os.listdir(dirpath)):
61
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
62
+
63
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
64
+
65
+ return res
66
+
67
+ def check_updates(self):
68
+ repo = git.Repo(self.path)
69
+ for fetch in repo.remote().fetch("--dry-run"):
70
+ if fetch.flags != fetch.HEAD_UPTODATE:
71
+ self.can_update = True
72
+ self.status = "behind"
73
+ return
74
+
75
+ self.can_update = False
76
+ self.status = "latest"
77
+
78
+ def fetch_and_reset_hard(self):
79
+ repo = git.Repo(self.path)
80
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
81
+ # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
82
+ repo.git.fetch('--all')
83
+ repo.git.reset('--hard', 'origin')
84
+
85
+
86
+ def list_extensions():
87
+ extensions.clear()
88
+
89
+ if not os.path.isdir(extensions_dir):
90
+ return
91
+
92
+ paths = []
93
+ for dirname in [extensions_dir, extensions_builtin_dir]:
94
+ if not os.path.isdir(dirname):
95
+ return
96
+
97
+ for extension_dirname in sorted(os.listdir(dirname)):
98
+ path = os.path.join(dirname, extension_dirname)
99
+ if not os.path.isdir(path):
100
+ continue
101
+
102
+ paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
103
+
104
+ for dirname, path, is_builtin in paths:
105
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
106
+ extensions.append(extension)
107
+
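Illustrative usage of the module above (the webui itself calls list_extensions() during startup; the printing here is just for demonstration):

from modules import extensions

extensions.list_extensions()
for ext in extensions.active():
    print(ext.name, "builtin" if ext.is_builtin else "user", ext.version or "(not a git repo)")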
sd/stable-diffusion-webui/modules/extra_networks.py ADDED
@@ -0,0 +1,147 @@
1
+ import re
2
+ from collections import defaultdict
3
+
4
+ from modules import errors
5
+
6
+ extra_network_registry = {}
7
+
8
+
9
+ def initialize():
10
+ extra_network_registry.clear()
11
+
12
+
13
+ def register_extra_network(extra_network):
14
+ extra_network_registry[extra_network.name] = extra_network
15
+
16
+
17
+ class ExtraNetworkParams:
18
+ def __init__(self, items=None):
19
+ self.items = items or []
20
+
21
+
22
+ class ExtraNetwork:
23
+ def __init__(self, name):
24
+ self.name = name
25
+
26
+ def activate(self, p, params_list):
27
+ """
28
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
29
+ Passes arguments related to this extra network in params_list.
30
+ User passes arguments by specifying this in his prompt:
31
+
32
+ <name:arg1:arg2:arg3>
33
+
34
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
35
+ separated by colon.
36
+
37
+ Even if the user does not mention this ExtraNetwork in the prompt, the call will still be made, with empty params_list -
38
+ in this case, all effects of this extra network should be disabled.
39
+
40
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
41
+
42
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
43
+
44
+ > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
45
+
46
+ params_list will be:
47
+
48
+ [
49
+ ExtraNetworkParams(items=["agm", "1.1"]),
50
+ ExtraNetworkParams(items=["ray"])
51
+ ]
52
+
53
+ """
54
+ raise NotImplementedError
55
+
56
+ def deactivate(self, p):
57
+ """
58
+ Called at the end of processing for housekeeping. No need to do anything here.
59
+ """
60
+
61
+ raise NotImplementedError
62
+
63
+
64
+ def activate(p, extra_network_data):
65
+ """call activate for extra networks in extra_network_data in specified order, then call
66
+ activate for all remaining registered networks with an empty argument list"""
67
+
68
+ for extra_network_name, extra_network_args in extra_network_data.items():
69
+ extra_network = extra_network_registry.get(extra_network_name, None)
70
+ if extra_network is None:
71
+ print(f"Skipping unknown extra network: {extra_network_name}")
72
+ continue
73
+
74
+ try:
75
+ extra_network.activate(p, extra_network_args)
76
+ except Exception as e:
77
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
78
+
79
+ for extra_network_name, extra_network in extra_network_registry.items():
80
+ args = extra_network_data.get(extra_network_name, None)
81
+ if args is not None:
82
+ continue
83
+
84
+ try:
85
+ extra_network.activate(p, [])
86
+ except Exception as e:
87
+ errors.display(e, f"activating extra network {extra_network_name}")
88
+
89
+
90
+ def deactivate(p, extra_network_data):
91
+ """call deactivate for extra networks in extra_network_data in specified order, then call
92
+ deactivate for all remaining registered networks"""
93
+
94
+ for extra_network_name, extra_network_args in extra_network_data.items():
95
+ extra_network = extra_network_registry.get(extra_network_name, None)
96
+ if extra_network is None:
97
+ continue
98
+
99
+ try:
100
+ extra_network.deactivate(p)
101
+ except Exception as e:
102
+ errors.display(e, f"deactivating extra network {extra_network_name}")
103
+
104
+ for extra_network_name, extra_network in extra_network_registry.items():
105
+ args = extra_network_data.get(extra_network_name, None)
106
+ if args is not None:
107
+ continue
108
+
109
+ try:
110
+ extra_network.deactivate(p)
111
+ except Exception as e:
112
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
113
+
114
+
115
+ re_extra_net = re.compile(r"<(\w+):([^>]+)>")
116
+
117
+
118
+ def parse_prompt(prompt):
119
+ res = defaultdict(list)
120
+
121
+ def found(m):
122
+ name = m.group(1)
123
+ args = m.group(2)
124
+
125
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
126
+
127
+ return ""
128
+
129
+ prompt = re.sub(re_extra_net, found, prompt)
130
+
131
+ return prompt, res
132
+
133
+
134
+ def parse_prompts(prompts):
135
+ res = []
136
+ extra_data = None
137
+
138
+ for prompt in prompts:
139
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
140
+
141
+ if extra_data is None:
142
+ extra_data = parsed_extra_data
143
+
144
+ res.append(updated_prompt)
145
+
146
+ return res, extra_data
147
+
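A quick illustration of parse_prompt() on the prompt from the activate() docstring above (assumes modules.extra_networks is importable):

from modules.extra_networks import parse_prompt

prompt, extra = parse_prompt("1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>")
# prompt -> "1girl,   "  (the <...> tags are removed, surrounding spaces remain)
# extra["hypernet"] holds two ExtraNetworkParams with items ["agm", "1.1"] and ["ray"]
# extra["extrasupernet"] holds one ExtraNetworkParams with items ["master", "12", "13", "14"]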
sd/stable-diffusion-webui/modules/extra_networks_hypernet.py ADDED
@@ -0,0 +1,27 @@
1
+ from modules import extra_networks, shared
2
+ from modules.hypernetworks import hypernetwork
3
+
4
+
5
+ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
6
+ def __init__(self):
7
+ super().__init__('hypernet')
8
+
9
+ def activate(self, p, params_list):
10
+ additional = shared.opts.sd_hypernetwork
11
+
12
+ if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
13
+ p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
14
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
15
+
16
+ names = []
17
+ multipliers = []
18
+ for params in params_list:
19
+ assert len(params.items) > 0
20
+
21
+ names.append(params.items[0])
22
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
23
+
24
+ hypernetwork.load_hypernetworks(names, multipliers)
25
+
26
+ def deactivate(self, p):
27
+ pass
sd/stable-diffusion-webui/modules/extras.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ import re
3
+ import shutil
4
+
5
+
6
+ import torch
7
+ import tqdm
8
+
9
+ from modules import shared, images, sd_models, sd_vae, sd_models_config
10
+ from modules.ui_common import plaintext_to_html
11
+ import gradio as gr
12
+ import safetensors.torch
13
+
14
+
15
+ def run_pnginfo(image):
16
+ if image is None:
17
+ return '', '', ''
18
+
19
+ geninfo, items = images.read_info_from_image(image)
20
+ items = {**{'parameters': geninfo}, **items}
21
+
22
+ info = ''
23
+ for key, text in items.items():
24
+ info += f"""
25
+ <div>
26
+ <p><b>{plaintext_to_html(str(key))}</b></p>
27
+ <p>{plaintext_to_html(str(text))}</p>
28
+ </div>
29
+ """.strip()+"\n"
30
+
31
+ if len(info) == 0:
32
+ message = "Nothing found in the image."
33
+ info = f"<div><p>{message}<p></div>"
34
+
35
+ return '', geninfo, info
36
+
37
+
38
+ def create_config(ckpt_result, config_source, a, b, c):
39
+ def config(x):
40
+ res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
41
+ return res if res != shared.sd_default_config else None
42
+
43
+ if config_source == 0:
44
+ cfg = config(a) or config(b) or config(c)
45
+ elif config_source == 1:
46
+ cfg = config(b)
47
+ elif config_source == 2:
48
+ cfg = config(c)
49
+ else:
50
+ cfg = None
51
+
52
+ if cfg is None:
53
+ return
54
+
55
+ filename, _ = os.path.splitext(ckpt_result)
56
+ checkpoint_filename = filename + ".yaml"
57
+
58
+ print("Copying config:")
59
+ print(" from:", cfg)
60
+ print(" to:", checkpoint_filename)
61
+ shutil.copyfile(cfg, checkpoint_filename)
62
+
63
+
64
+ checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
65
+
66
+
67
+ def to_half(tensor, enable):
68
+ if enable and tensor.dtype == torch.float:
69
+ return tensor.half()
70
+
71
+ return tensor
72
+
73
+
74
+ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
75
+ shared.state.begin()
76
+ shared.state.job = 'model-merge'
77
+
78
+ def fail(message):
79
+ shared.state.textinfo = message
80
+ shared.state.end()
81
+ return [*[gr.update() for _ in range(4)], message]
82
+
83
+ def weighted_sum(theta0, theta1, alpha):
84
+ return ((1 - alpha) * theta0) + (alpha * theta1)
85
+
86
+ def get_difference(theta1, theta2):
87
+ return theta1 - theta2
88
+
89
+ def add_difference(theta0, theta1_2_diff, alpha):
90
+ return theta0 + (alpha * theta1_2_diff)
91
+
92
+ def filename_weighted_sum():
93
+ a = primary_model_info.model_name
94
+ b = secondary_model_info.model_name
95
+ Ma = round(1 - multiplier, 2)
96
+ Mb = round(multiplier, 2)
97
+
98
+ return f"{Ma}({a}) + {Mb}({b})"
99
+
100
+ def filename_add_difference():
101
+ a = primary_model_info.model_name
102
+ b = secondary_model_info.model_name
103
+ c = tertiary_model_info.model_name
104
+ M = round(multiplier, 2)
105
+
106
+ return f"{a} + {M}({b} - {c})"
107
+
108
+ def filename_nothing():
109
+ return primary_model_info.model_name
110
+
111
+ theta_funcs = {
112
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
113
+ "Add difference": (filename_add_difference, get_difference, add_difference),
114
+ "No interpolation": (filename_nothing, None, None),
115
+ }
116
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
117
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
118
+
119
+ if not primary_model_name:
120
+ return fail("Failed: Merging requires a primary model.")
121
+
122
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
123
+
124
+ if theta_func2 and not secondary_model_name:
125
+ return fail("Failed: Merging requires a secondary model.")
126
+
127
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
128
+
129
+ if theta_func1 and not tertiary_model_name:
130
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
131
+
132
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
133
+
134
+ result_is_inpainting_model = False
135
+ result_is_instruct_pix2pix_model = False
136
+
137
+ if theta_func2:
138
+ shared.state.textinfo = f"Loading B"
139
+ print(f"Loading {secondary_model_info.filename}...")
140
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
141
+ else:
142
+ theta_1 = None
143
+
144
+ if theta_func1:
145
+ shared.state.textinfo = f"Loading C"
146
+ print(f"Loading {tertiary_model_info.filename}...")
147
+ theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
148
+
149
+ shared.state.textinfo = 'Merging B and C'
150
+ shared.state.sampling_steps = len(theta_1.keys())
151
+ for key in tqdm.tqdm(theta_1.keys()):
152
+ if key in checkpoint_dict_skip_on_merge:
153
+ continue
154
+
155
+ if 'model' in key:
156
+ if key in theta_2:
157
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
158
+ theta_1[key] = theta_func1(theta_1[key], t2)
159
+ else:
160
+ theta_1[key] = torch.zeros_like(theta_1[key])
161
+
162
+ shared.state.sampling_step += 1
163
+ del theta_2
164
+
165
+ shared.state.nextjob()
166
+
167
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
168
+ print(f"Loading {primary_model_info.filename}...")
169
+ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
170
+
171
+ print("Merging...")
172
+ shared.state.textinfo = 'Merging A and B'
173
+ shared.state.sampling_steps = len(theta_0.keys())
174
+ for key in tqdm.tqdm(theta_0.keys()):
175
+ if theta_1 and 'model' in key and key in theta_1:
176
+
177
+ if key in checkpoint_dict_skip_on_merge:
178
+ continue
179
+
180
+ a = theta_0[key]
181
+ b = theta_1[key]
182
+
183
+ # this enables merging an inpainting model (A) with another one (B);
184
+ # where a normal model would have 4 channels for latent space, an inpainting model would
185
+ # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
186
+ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
187
+ if a.shape[1] == 4 and b.shape[1] == 9:
188
+ raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
189
+ if a.shape[1] == 4 and b.shape[1] == 8:
190
+ raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
191
+
192
+ if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
193
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
194
+ result_is_instruct_pix2pix_model = True
195
+ else:
196
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
197
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
198
+ result_is_inpainting_model = True
199
+ else:
200
+ theta_0[key] = theta_func2(a, b, multiplier)
201
+
202
+ theta_0[key] = to_half(theta_0[key], save_as_half)
203
+
204
+ shared.state.sampling_step += 1
205
+
206
+ del theta_1
207
+
208
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
209
+ if bake_in_vae_filename is not None:
210
+ print(f"Baking in VAE from {bake_in_vae_filename}")
211
+ shared.state.textinfo = 'Baking in VAE'
212
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
213
+
214
+ for key in vae_dict.keys():
215
+ theta_0_key = 'first_stage_model.' + key
216
+ if theta_0_key in theta_0:
217
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
218
+
219
+ del vae_dict
220
+
221
+ if save_as_half and not theta_func2:
222
+ for key in theta_0.keys():
223
+ theta_0[key] = to_half(theta_0[key], save_as_half)
224
+
225
+ if discard_weights:
226
+ regex = re.compile(discard_weights)
227
+ for key in list(theta_0):
228
+ if re.search(regex, key):
229
+ theta_0.pop(key, None)
230
+
231
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
232
+
233
+ filename = filename_generator() if custom_name == '' else custom_name
234
+ filename += ".inpainting" if result_is_inpainting_model else ""
235
+ filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
236
+ filename += "." + checkpoint_format
237
+
238
+ output_modelname = os.path.join(ckpt_dir, filename)
239
+
240
+ shared.state.nextjob()
241
+ shared.state.textinfo = "Saving"
242
+ print(f"Saving to {output_modelname}...")
243
+
244
+ _, extension = os.path.splitext(output_modelname)
245
+ if extension.lower() == ".safetensors":
246
+ safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
247
+ else:
248
+ torch.save(theta_0, output_modelname)
249
+
250
+ sd_models.list_models()
251
+
252
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
253
+
254
+ print(f"Checkpoint saved to {output_modelname}.")
255
+ shared.state.textinfo = "Checkpoint saved"
256
+ shared.state.end()
257
+
258
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
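The three interpolation modes in run_modelmerger() boil down to simple per-tensor arithmetic; a toy example with tiny made-up tensors, purely to show the formulas:

import torch

theta0 = torch.tensor([1.0, 2.0])   # A
theta1 = torch.tensor([3.0, 6.0])   # B
theta2 = torch.tensor([2.0, 2.0])   # C
alpha = 0.5

weighted = (1 - alpha) * theta0 + alpha * theta1      # "Weighted sum"   -> tensor([2., 4.])
added    = theta0 + alpha * (theta1 - theta2)         # "Add difference" -> tensor([1.5000, 4.0000])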
sd/stable-diffusion-webui/modules/face_restoration.py ADDED
@@ -0,0 +1,19 @@
1
+ from modules import shared
2
+
3
+
4
+ class FaceRestoration:
5
+ def name(self):
6
+ return "None"
7
+
8
+ def restore(self, np_image):
9
+ return np_image
10
+
11
+
12
+ def restore_faces(np_image):
13
+ face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14
+ if len(face_restorers) == 0:
15
+ return np_image
16
+
17
+ face_restorer = face_restorers[0]
18
+
19
+ return face_restorer.restore(np_image)
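A minimal sketch of how a concrete restorer plugs into this interface; the NoopRestorer class and its registration are illustrative only, not part of this diff:

import numpy as np
from modules import face_restoration, shared

class NoopRestorer(face_restoration.FaceRestoration):
    def name(self):
        return "Noop"

    def restore(self, np_image: np.ndarray) -> np.ndarray:
        # a real implementation would return the restored image here
        return np_image

shared.face_restorers.append(NoopRestorer())
# restore_faces() will pick this restorer when opts.face_restoration_model == "Noop"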
sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py ADDED
@@ -0,0 +1,402 @@
1
+ import base64
2
+ import html
3
+ import io
4
+ import math
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+
9
+ import gradio as gr
10
+ from modules.paths import data_path
11
+ from modules import shared, ui_tempdir, script_callbacks
12
+ import tempfile
13
+ from PIL import Image
14
+
15
+ re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
16
+ re_param = re.compile(re_param_code)
17
+ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
18
+ re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
19
+ type_of_gr_update = type(gr.update())
20
+
21
+ paste_fields = {}
22
+ registered_param_bindings = []
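The regexes defined above do most of the infotext parsing; as a quick check, here they are applied to the sample settings line from the parse_generation_parameters() docstring further down (assumes the module is importable):

from modules.generation_parameters_copypaste import re_param, re_imagesize

line = 'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b'
print(re_param.findall(line))
# [('Steps', '20'), ('Sampler', 'Euler a'), ('CFG scale', '7'),
#  ('Seed', '965400086'), ('Size', '512x512'), ('Model hash', '45dee52b')]
print(re_imagesize.match('512x512').groups())  # ('512', '512')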
23
+
24
+
25
+ class ParamBinding:
26
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
27
+ self.paste_button = paste_button
28
+ self.tabname = tabname
29
+ self.source_text_component = source_text_component
30
+ self.source_image_component = source_image_component
31
+ self.source_tabname = source_tabname
32
+ self.override_settings_component = override_settings_component
33
+
34
+
35
+ def reset():
36
+ paste_fields.clear()
37
+
38
+
39
+ def quote(text):
40
+ if ',' not in str(text):
41
+ return text
42
+
43
+ text = str(text)
44
+ text = text.replace('\\', '\\\\')
45
+ text = text.replace('"', '\\"')
46
+ return f'"{text}"'
47
+
48
+
49
+ def image_from_url_text(filedata):
50
+ if filedata is None:
51
+ return None
52
+
53
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
54
+ filedata = filedata[0]
55
+
56
+ if type(filedata) == dict and filedata.get("is_file", False):
57
+ filename = filedata["name"]
58
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
59
+ assert is_in_right_dir, 'trying to open image file outside of allowed directories'
60
+
61
+ return Image.open(filename)
62
+
63
+ if type(filedata) == list:
64
+ if len(filedata) == 0:
65
+ return None
66
+
67
+ filedata = filedata[0]
68
+
69
+ if filedata.startswith("data:image/png;base64,"):
70
+ filedata = filedata[len("data:image/png;base64,"):]
71
+
72
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
73
+ image = Image.open(io.BytesIO(filedata))
74
+ return image
75
+
76
+
77
+ def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
78
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
79
+
80
+ # backwards compatibility for existing extensions
81
+ import modules.ui
82
+ if tabname == 'txt2img':
83
+ modules.ui.txt2img_paste_fields = fields
84
+ elif tabname == 'img2img':
85
+ modules.ui.img2img_paste_fields = fields
86
+
87
+
88
+ def create_buttons(tabs_list):
89
+ buttons = {}
90
+ for tab in tabs_list:
91
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
92
+ return buttons
93
+
94
+
95
+ def bind_buttons(buttons, send_image, send_generate_info):
96
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
97
+ for tabname, button in buttons.items():
98
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
99
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
100
+
101
+ register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
102
+
103
+
104
+ def register_paste_params_button(binding: ParamBinding):
105
+ registered_param_bindings.append(binding)
106
+
107
+
108
+ def connect_paste_params_buttons():
109
+ binding: ParamBinding
110
+ for binding in registered_param_bindings:
111
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
112
+ fields = paste_fields[binding.tabname]["fields"]
113
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
114
+
115
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
116
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
117
+
118
+ if binding.source_image_component and destination_image_component:
119
+ if isinstance(binding.source_image_component, gr.Gallery):
120
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
121
+ jsfunc = "extract_image_from_gallery"
122
+ else:
123
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
124
+ jsfunc = None
125
+
126
+ binding.paste_button.click(
127
+ fn=func,
128
+ _js=jsfunc,
129
+ inputs=[binding.source_image_component],
130
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
131
+ )
132
+
133
+ if binding.source_text_component is not None and fields is not None:
134
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
135
+
136
+ if binding.source_tabname is not None and fields is not None:
137
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
138
+ binding.paste_button.click(
139
+ fn=lambda *x: x,
140
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
141
+ outputs=[field for field, name in fields if name in paste_field_names],
142
+ )
143
+
144
+ binding.paste_button.click(
145
+ fn=None,
146
+ _js=f"switch_to_{binding.tabname}",
147
+ inputs=None,
148
+ outputs=None,
149
+ )
150
+
151
+
152
+ def send_image_and_dimensions(x):
153
+ if isinstance(x, Image.Image):
154
+ img = x
155
+ else:
156
+ img = image_from_url_text(x)
157
+
158
+ if shared.opts.send_size and isinstance(img, Image.Image):
159
+ w = img.width
160
+ h = img.height
161
+ else:
162
+ w = gr.update()
163
+ h = gr.update()
164
+
165
+ return img, w, h
166
+
167
+
168
+
169
+ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
170
+ """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
171
+
172
+ Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
173
+ parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
174
+
175
+ If the infotext has no hash, then a hypernet with the same name will be selected instead.
176
+ """
177
+ hypernet_name = hypernet_name.lower()
178
+ if hypernet_hash is not None:
179
+ # Try to match the hash in the name
180
+ for hypernet_key in shared.hypernetworks.keys():
181
+ result = re_hypernet_hash.search(hypernet_key)
182
+ if result is not None and result[1] == hypernet_hash:
183
+ return hypernet_key
184
+ else:
185
+ # Fall back to a hypernet with the same name
186
+ for hypernet_key in shared.hypernetworks.keys():
187
+ if hypernet_key.lower().startswith(hypernet_name):
188
+ return hypernet_key
189
+
190
+ return None
191
+
192
+
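A short usage sketch (not part of the upload), assuming the registry in shared.hypernetworks uses the name-steps(hash) key convention described above; re_hypernet_hash is defined earlier in this module, outside this part of the diff:

    from modules import generation_parameters_copypaste as gpc

    # shared.hypernetworks might contain a key such as "ke-ta-10000(1234abcd)"
    key = gpc.find_hypernetwork_key("ke-ta", hypernet_hash="1234abcd")
    # -> "ke-ta-10000(1234abcd)" when a key with that hash exists; otherwise None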
193
+ def restore_old_hires_fix_params(res):
194
+ """for infotexts that specify old First pass size parameter, convert it into
195
+ width, height, and hr scale"""
196
+
197
+ firstpass_width = res.get('First pass size-1', None)
198
+ firstpass_height = res.get('First pass size-2', None)
199
+
200
+ if shared.opts.use_old_hires_fix_width_height:
201
+ hires_width = int(res.get("Hires resize-1", 0))
202
+ hires_height = int(res.get("Hires resize-2", 0))
203
+
204
+ if hires_width and hires_height:
205
+ res['Size-1'] = hires_width
206
+ res['Size-2'] = hires_height
207
+ return
208
+
209
+ if firstpass_width is None or firstpass_height is None:
210
+ return
211
+
212
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
213
+ width = int(res.get("Size-1", 512))
214
+ height = int(res.get("Size-2", 512))
215
+
216
+ if firstpass_width == 0 or firstpass_height == 0:
217
+ from modules import processing
218
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
219
+
220
+ res['Size-1'] = firstpass_width
221
+ res['Size-2'] = firstpass_height
222
+ res['Hires resize-1'] = width
223
+ res['Hires resize-2'] = height
224
+
225
+
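An illustrative before/after of the conversion (the values are made up; the original "First pass size" keys themselves are left in the dict):

    # Old-style infotext: "First pass size" held the low-res pass, "Size" the final output.
    res = {"First pass size-1": "512", "First pass size-2": "512",
           "Size-1": "960", "Size-2": "960"}

    # After restore_old_hires_fix_params(res), with use_old_hires_fix_width_height off:
    #   res["Size-1"] == 512, res["Size-2"] == 512                   (first pass becomes the base size)
    #   res["Hires resize-1"] == 960, res["Hires resize-2"] == 960   (old final size moves here)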
226
+ def parse_generation_parameters(x: str):
227
+ """parses generation parameters string, the one you see in text field under the picture in UI:
228
+ ```
229
+ girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
230
+ Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
231
+ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
232
+ ```
233
+
234
+ returns a dict with field values
235
+ """
236
+
237
+ res = {}
238
+
239
+ prompt = ""
240
+ negative_prompt = ""
241
+
242
+ done_with_prompt = False
243
+
244
+ *lines, lastline = x.strip().split("\n")
245
+ if len(re_param.findall(lastline)) < 3:
246
+ lines.append(lastline)
247
+ lastline = ''
248
+
249
+ for i, line in enumerate(lines):
250
+ line = line.strip()
251
+ if line.startswith("Negative prompt:"):
252
+ done_with_prompt = True
253
+ line = line[16:].strip()
254
+
255
+ if done_with_prompt:
256
+ negative_prompt += ("" if negative_prompt == "" else "\n") + line
257
+ else:
258
+ prompt += ("" if prompt == "" else "\n") + line
259
+
260
+ res["Prompt"] = prompt
261
+ res["Negative prompt"] = negative_prompt
262
+
263
+ for k, v in re_param.findall(lastline):
264
+ v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
265
+ m = re_imagesize.match(v)
266
+ if m is not None:
267
+ res[k+"-1"] = m.group(1)
268
+ res[k+"-2"] = m.group(2)
269
+ else:
270
+ res[k] = v
271
+
272
+ # Missing CLIP skip means it was set to 1 (the default)
273
+ if "Clip skip" not in res:
274
+ res["Clip skip"] = "1"
275
+
276
+ hypernet = res.get("Hypernet", None)
277
+ if hypernet is not None:
278
+ res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
279
+
280
+ if "Hires resize-1" not in res:
281
+ res["Hires resize-1"] = 0
282
+ res["Hires resize-2"] = 0
283
+
284
+ restore_old_hires_fix_params(res)
285
+
286
+ return res
287
+
288
+
289
+ settings_map = {}
290
+
291
+ infotext_to_setting_name_mapping = [
292
+ ('Clip skip', 'CLIP_stop_at_last_layers', ),
293
+ ('Conditional mask weight', 'inpainting_mask_weight'),
294
+ ('Model hash', 'sd_model_checkpoint'),
295
+ ('ENSD', 'eta_noise_seed_delta'),
296
+ ('Noise multiplier', 'initial_noise_multiplier'),
297
+ ('Eta', 'eta_ancestral'),
298
+ ('Eta DDIM', 'eta_ddim'),
299
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
300
+ ]
301
+
302
+
303
+ def create_override_settings_dict(text_pairs):
304
+ """creates processing's override_settings parameters from gradio's multiselect
305
+
306
+ Example input:
307
+ ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
308
+
309
+ Example output:
310
+ {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
311
+ """
312
+
313
+ res = {}
314
+
315
+ params = {}
316
+ for pair in text_pairs:
317
+ k, v = pair.split(":", maxsplit=1)
318
+
319
+ params[k] = v.strip()
320
+
321
+ for param_name, setting_name in infotext_to_setting_name_mapping:
322
+ value = params.get(param_name, None)
323
+
324
+ if value is None:
325
+ continue
326
+
327
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
328
+
329
+ return res
330
+
331
+
332
+ def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
333
+ def paste_func(prompt):
334
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
335
+ filename = os.path.join(data_path, "params.txt")
336
+ if os.path.exists(filename):
337
+ with open(filename, "r", encoding="utf8") as file:
338
+ prompt = file.read()
339
+
340
+ params = parse_generation_parameters(prompt)
341
+ script_callbacks.infotext_pasted_callback(prompt, params)
342
+ res = []
343
+
344
+ for output, key in paste_fields:
345
+ if callable(key):
346
+ v = key(params)
347
+ else:
348
+ v = params.get(key, None)
349
+
350
+ if v is None:
351
+ res.append(gr.update())
352
+ elif isinstance(v, type_of_gr_update):
353
+ res.append(v)
354
+ else:
355
+ try:
356
+ valtype = type(output.value)
357
+
358
+ if valtype == bool and v == "False":
359
+ val = False
360
+ else:
361
+ val = valtype(v)
362
+
363
+ res.append(gr.update(value=val))
364
+ except Exception:
365
+ res.append(gr.update())
366
+
367
+ return res
368
+
369
+ if override_settings_component is not None:
370
+ def paste_settings(params):
371
+ vals = {}
372
+
373
+ for param_name, setting_name in infotext_to_setting_name_mapping:
374
+ v = params.get(param_name, None)
375
+ if v is None:
376
+ continue
377
+
378
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
379
+ continue
380
+
381
+ v = shared.opts.cast_value(setting_name, v)
382
+ current_value = getattr(shared.opts, setting_name, None)
383
+
384
+ if v == current_value:
385
+ continue
386
+
387
+ vals[param_name] = v
388
+
389
+ vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
390
+
391
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
392
+
393
+ paste_fields = paste_fields + [(override_settings_component, paste_settings)]
394
+
395
+ button.click(
396
+ fn=paste_func,
397
+ _js=f"recalculate_prompts_{tabname}",
398
+ inputs=[input_comp],
399
+ outputs=[x[0] for x in paste_fields],
400
+ )
401
+
402
+
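A minimal usage sketch for parse_generation_parameters (not part of the upload); it assumes the webui environment is initialized, since the helper consults shared.opts while restoring old hires-fix parameters:

    from modules.generation_parameters_copypaste import parse_generation_parameters

    infotext = "\n".join([
        "girl with an artist's beret, detailed, intricate",
        "Negative prompt: ugly, deformed",
        "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b",
    ])

    res = parse_generation_parameters(infotext)
    # res["Prompt"] == "girl with an artist's beret, detailed, intricate"
    # res["Negative prompt"] == "ugly, deformed"
    # res["Steps"] == "20", res["Size-1"] == "512", res["Size-2"] == "512"
    # res["Clip skip"] == "1"   (filled in because it was missing from the infotext)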
sd/stable-diffusion-webui/modules/gfpgan_model.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import facexlib
6
+ import gfpgan
7
+
8
+ import modules.face_restoration
9
+ from modules import paths, shared, devices, modelloader
10
+
11
+ model_dir = "GFPGAN"
12
+ user_path = None
13
+ model_path = os.path.join(paths.models_path, model_dir)
14
+ model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
15
+ have_gfpgan = False
16
+ loaded_gfpgan_model = None
17
+
18
+
19
+ def gfpgann():
20
+ global loaded_gfpgan_model
21
+ global model_path
22
+ if loaded_gfpgan_model is not None:
23
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
24
+ return loaded_gfpgan_model
25
+
26
+ if gfpgan_constructor is None:
27
+ return None
28
+
29
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
30
+ if len(models) == 1 and "http" in models[0]:
31
+ model_file = models[0]
32
+ elif len(models) != 0:
33
+ latest_file = max(models, key=os.path.getctime)
34
+ model_file = latest_file
35
+ else:
36
+ print("Unable to load gfpgan model!")
37
+ return None
38
+ if hasattr(facexlib.detection.retinaface, 'device'):
39
+ facexlib.detection.retinaface.device = devices.device_gfpgan
40
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
41
+ loaded_gfpgan_model = model
42
+
43
+ return model
44
+
45
+
46
+ def send_model_to(model, device):
47
+ model.gfpgan.to(device)
48
+ model.face_helper.face_det.to(device)
49
+ model.face_helper.face_parse.to(device)
50
+
51
+
52
+ def gfpgan_fix_faces(np_image):
53
+ model = gfpgann()
54
+ if model is None:
55
+ return np_image
56
+
57
+ send_model_to(model, devices.device_gfpgan)
58
+
59
+ np_image_bgr = np_image[:, :, ::-1]
60
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
61
+ np_image = gfpgan_output_bgr[:, :, ::-1]
62
+
63
+ model.face_helper.clean_all()
64
+
65
+ if shared.opts.face_restoration_unload:
66
+ send_model_to(model, devices.cpu)
67
+
68
+ return np_image
69
+
70
+
71
+ gfpgan_constructor = None
72
+
73
+
74
+ def setup_model(dirname):
75
+ global model_path
76
+ if not os.path.exists(model_path):
77
+ os.makedirs(model_path)
78
+
79
+ try:
80
+ from gfpgan import GFPGANer
81
+ from facexlib import detection, parsing
82
+ global user_path
83
+ global have_gfpgan
84
+ global gfpgan_constructor
85
+
86
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
87
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
88
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
89
+
90
+ def my_load_file_from_url(**kwargs):
91
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
92
+
93
+ def facex_load_file_from_url(**kwargs):
94
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
95
+
96
+ def facex_load_file_from_url2(**kwargs):
97
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
98
+
99
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
100
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
101
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
102
+ user_path = dirname
103
+ have_gfpgan = True
104
+ gfpgan_constructor = GFPGANer
105
+
106
+ class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
107
+ def name(self):
108
+ return "GFPGAN"
109
+
110
+ def restore(self, np_image):
111
+ return gfpgan_fix_faces(np_image)
112
+
113
+ shared.face_restorers.append(FaceRestorerGFPGAN())
114
+ except Exception:
115
+ print("Error setting up GFPGAN:", file=sys.stderr)
116
+ print(traceback.format_exc(), file=sys.stderr)
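A usage sketch (not part of the upload): it assumes the webui has already called setup_model() at startup so gfpgan_constructor is populated, and "face.png" is only a placeholder path:

    import numpy as np
    from PIL import Image
    from modules import gfpgan_model

    np_image = np.array(Image.open("face.png").convert("RGB"))   # RGB uint8 array
    restored = gfpgan_model.gfpgan_fix_faces(np_image)           # RGB in, RGB out; the BGR flip is handled internally
    Image.fromarray(restored).save("face_restored.png")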
sd/stable-diffusion-webui/modules/hashes.py ADDED
@@ -0,0 +1,91 @@
1
+ import hashlib
2
+ import json
3
+ import os.path
4
+
5
+ import filelock
6
+
7
+ from modules import shared
8
+ from modules.paths import data_path
9
+
10
+
11
+ cache_filename = os.path.join(data_path, "cache.json")
12
+ cache_data = None
13
+
14
+
15
+ def dump_cache():
16
+ with filelock.FileLock(cache_filename+".lock"):
17
+ with open(cache_filename, "w", encoding="utf8") as file:
18
+ json.dump(cache_data, file, indent=4)
19
+
20
+
21
+ def cache(subsection):
22
+ global cache_data
23
+
24
+ if cache_data is None:
25
+ with filelock.FileLock(cache_filename+".lock"):
26
+ if not os.path.isfile(cache_filename):
27
+ cache_data = {}
28
+ else:
29
+ with open(cache_filename, "r", encoding="utf8") as file:
30
+ cache_data = json.load(file)
31
+
32
+ s = cache_data.get(subsection, {})
33
+ cache_data[subsection] = s
34
+
35
+ return s
36
+
37
+
38
+ def calculate_sha256(filename):
39
+ hash_sha256 = hashlib.sha256()
40
+ blksize = 1024 * 1024
41
+
42
+ with open(filename, "rb") as f:
43
+ for chunk in iter(lambda: f.read(blksize), b""):
44
+ hash_sha256.update(chunk)
45
+
46
+ return hash_sha256.hexdigest()
47
+
48
+
49
+ def sha256_from_cache(filename, title):
50
+ hashes = cache("hashes")
51
+ ondisk_mtime = os.path.getmtime(filename)
52
+
53
+ if title not in hashes:
54
+ return None
55
+
56
+ cached_sha256 = hashes[title].get("sha256", None)
57
+ cached_mtime = hashes[title].get("mtime", 0)
58
+
59
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
60
+ return None
61
+
62
+ return cached_sha256
63
+
64
+
65
+ def sha256(filename, title):
66
+ hashes = cache("hashes")
67
+
68
+ sha256_value = sha256_from_cache(filename, title)
69
+ if sha256_value is not None:
70
+ return sha256_value
71
+
72
+ if shared.cmd_opts.no_hashing:
73
+ return None
74
+
75
+ print(f"Calculating sha256 for {filename}: ", end='')
76
+ sha256_value = calculate_sha256(filename)
77
+ print(f"{sha256_value}")
78
+
79
+ hashes[title] = {
80
+ "mtime": os.path.getmtime(filename),
81
+ "sha256": sha256_value,
82
+ }
83
+
84
+ dump_cache()
85
+
86
+ return sha256_value
87
+
88
+
89
+
90
+
91
+
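A usage sketch (the path and cache title are placeholders): the first call computes and caches the digest; later lookups hit cache.json as long as the file's mtime has not increased:

    from modules import hashes

    filename = "models/Stable-diffusion/model.safetensors"   # placeholder path
    title = "checkpoint/model"                                # placeholder cache key

    digest = hashes.sha256(filename, title)             # computes, prints, and caches the value
    cached = hashes.sha256_from_cache(filename, title)  # returns the cached value, or None if stale or missing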
sd/stable-diffusion-webui/modules/images.py ADDED
@@ -0,0 +1,669 @@
1
+ import datetime
2
+ import sys
3
+ import traceback
4
+
5
+ import pytz
6
+ import io
7
+ import math
8
+ import os
9
+ from collections import namedtuple
10
+ import re
11
+
12
+ import numpy as np
13
+ import piexif
14
+ import piexif.helper
15
+ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
16
+ from fonts.ttf import Roboto
17
+ import string
18
+ import json
19
+ import hashlib
20
+
21
+ from modules import sd_samplers, shared, script_callbacks, errors
22
+ from modules.shared import opts, cmd_opts
23
+
24
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
25
+
26
+
27
+ def image_grid(imgs, batch_size=1, rows=None):
28
+ if rows is None:
29
+ if opts.n_rows > 0:
30
+ rows = opts.n_rows
31
+ elif opts.n_rows == 0:
32
+ rows = batch_size
33
+ elif opts.grid_prevent_empty_spots:
34
+ rows = math.floor(math.sqrt(len(imgs)))
35
+ while len(imgs) % rows != 0:
36
+ rows -= 1
37
+ else:
38
+ rows = math.sqrt(len(imgs))
39
+ rows = round(rows)
40
+ if rows > len(imgs):
41
+ rows = len(imgs)
42
+
43
+ cols = math.ceil(len(imgs) / rows)
44
+
45
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
46
+ script_callbacks.image_grid_callback(params)
47
+
48
+ w, h = imgs[0].size
49
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
50
+
51
+ for i, img in enumerate(params.imgs):
52
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
53
+
54
+ return grid
55
+
56
+
57
+ Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
58
+
59
+
60
+ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
61
+ w = image.width
62
+ h = image.height
63
+
64
+ non_overlap_width = tile_w - overlap
65
+ non_overlap_height = tile_h - overlap
66
+
67
+ cols = math.ceil((w - overlap) / non_overlap_width)
68
+ rows = math.ceil((h - overlap) / non_overlap_height)
69
+
70
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
71
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
72
+
73
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
74
+ for row in range(rows):
75
+ row_images = []
76
+
77
+ y = int(row * dy)
78
+
79
+ if y + tile_h >= h:
80
+ y = h - tile_h
81
+
82
+ for col in range(cols):
83
+ x = int(col * dx)
84
+
85
+ if x + tile_w >= w:
86
+ x = w - tile_w
87
+
88
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
89
+
90
+ row_images.append([x, tile_w, tile])
91
+
92
+ grid.tiles.append([y, tile_h, row_images])
93
+
94
+ return grid
95
+
96
+
97
+ def combine_grid(grid):
98
+ def make_mask_image(r):
99
+ r = r * 255 / grid.overlap
100
+ r = r.astype(np.uint8)
101
+ return Image.fromarray(r, 'L')
102
+
103
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
104
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
105
+
106
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
107
+ for y, h, row in grid.tiles:
108
+ combined_row = Image.new("RGB", (grid.image_w, h))
109
+ for x, w, tile in row:
110
+ if x == 0:
111
+ combined_row.paste(tile, (0, 0))
112
+ continue
113
+
114
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
115
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
116
+
117
+ if y == 0:
118
+ combined_image.paste(combined_row, (0, 0))
119
+ continue
120
+
121
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
122
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
123
+
124
+ return combined_image
125
+
126
+
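A sketch of the tiling round trip used by tiled upscalers (not part of the upload; the per-tile processing here is just a stand-in filter and "large.png" is a placeholder):

    from PIL import Image, ImageFilter
    from modules import images

    img = Image.open("large.png").convert("RGB")
    grid = images.split_grid(img, tile_w=512, tile_h=512, overlap=64)

    for y, h, row in grid.tiles:
        for i, (x, w, tile) in enumerate(row):
            row[i] = [x, w, tile.filter(ImageFilter.SHARPEN)]  # stand-in for real per-tile work

    result = images.combine_grid(grid)                         # overlapping regions are blended with linear masks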
127
+ class GridAnnotation:
128
+ def __init__(self, text='', is_active=True):
129
+ self.text = text
130
+ self.is_active = is_active
131
+ self.size = None
132
+
133
+
134
+ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
135
+ def wrap(drawing, text, font, line_length):
136
+ lines = ['']
137
+ for word in text.split():
138
+ line = f'{lines[-1]} {word}'.strip()
139
+ if drawing.textlength(line, font=font) <= line_length:
140
+ lines[-1] = line
141
+ else:
142
+ lines.append(word)
143
+ return lines
144
+
145
+ def get_font(fontsize):
146
+ try:
147
+ return ImageFont.truetype(opts.font or Roboto, fontsize)
148
+ except Exception:
149
+ return ImageFont.truetype(Roboto, fontsize)
150
+
151
+ def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
152
+ for i, line in enumerate(lines):
153
+ fnt = initial_fnt
154
+ fontsize = initial_fontsize
155
+ while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
156
+ fontsize -= 1
157
+ fnt = get_font(fontsize)
158
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
159
+
160
+ if not line.is_active:
161
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
162
+
163
+ draw_y += line.size[1] + line_spacing
164
+
165
+ fontsize = (width + height) // 25
166
+ line_spacing = fontsize // 2
167
+
168
+ fnt = get_font(fontsize)
169
+
170
+ color_active = (0, 0, 0)
171
+ color_inactive = (153, 153, 153)
172
+
173
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
174
+
175
+ cols = im.width // width
176
+ rows = im.height // height
177
+
178
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
179
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
180
+
181
+ calc_img = Image.new("RGB", (1, 1), "white")
182
+ calc_d = ImageDraw.Draw(calc_img)
183
+
184
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
185
+ items = [] + texts
186
+ texts.clear()
187
+
188
+ for line in items:
189
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
190
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
191
+
192
+ for line in texts:
193
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
194
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
195
+ line.allowed_width = allowed_width
196
+
197
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
198
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
199
+
200
+ pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
201
+
202
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
203
+
204
+ for row in range(rows):
205
+ for col in range(cols):
206
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
207
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
208
+
209
+ d = ImageDraw.Draw(result)
210
+
211
+ for col in range(cols):
212
+ x = pad_left + (width + margin) * col + width / 2
213
+ y = pad_top / 2 - hor_text_heights[col] / 2
214
+
215
+ draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
216
+
217
+ for row in range(rows):
218
+ x = pad_left / 2
219
+ y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
220
+
221
+ draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
222
+
223
+ return result
224
+
225
+
226
+ def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
227
+ prompts = all_prompts[1:]
228
+ boundary = math.ceil(len(prompts) / 2)
229
+
230
+ prompts_horiz = prompts[:boundary]
231
+ prompts_vert = prompts[boundary:]
232
+
233
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
234
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
235
+
236
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
237
+
238
+
239
+ def resize_image(resize_mode, im, width, height, upscaler_name=None):
240
+ """
241
+ Resizes an image with the specified resize_mode, width, and height.
242
+
243
+ Args:
244
+ resize_mode: The mode to use when resizing the image.
245
+ 0: Resize the image to the specified width and height.
246
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
247
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, then center it within the target dimensions and fill the remaining space with data stretched from the image's edges.
248
+ im: The image to resize.
249
+ width: The width to resize the image to.
250
+ height: The height to resize the image to.
251
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
252
+ """
253
+
254
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
255
+
256
+ def resize(im, w, h):
257
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
258
+ return im.resize((w, h), resample=LANCZOS)
259
+
260
+ scale = max(w / im.width, h / im.height)
261
+
262
+ if scale > 1.0:
263
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
264
+ assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
265
+
266
+ upscaler = upscalers[0]
267
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
268
+
269
+ if im.width != w or im.height != h:
270
+ im = im.resize((w, h), resample=LANCZOS)
271
+
272
+ return im
273
+
274
+ if resize_mode == 0:
275
+ res = resize(im, width, height)
276
+
277
+ elif resize_mode == 1:
278
+ ratio = width / height
279
+ src_ratio = im.width / im.height
280
+
281
+ src_w = width if ratio > src_ratio else im.width * height // im.height
282
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
283
+
284
+ resized = resize(im, src_w, src_h)
285
+ res = Image.new("RGB", (width, height))
286
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
287
+
288
+ else:
289
+ ratio = width / height
290
+ src_ratio = im.width / im.height
291
+
292
+ src_w = width if ratio < src_ratio else im.width * height // im.height
293
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
294
+
295
+ resized = resize(im, src_w, src_h)
296
+ res = Image.new("RGB", (width, height))
297
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
298
+
299
+ if ratio < src_ratio:
300
+ fill_height = height // 2 - src_h // 2
301
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
302
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
303
+ elif ratio > src_ratio:
304
+ fill_width = width // 2 - src_w // 2
305
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
306
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
307
+
308
+ return res
309
+
310
+
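A sketch of the three resize modes (it runs inside the webui, since the helper may consult the configured upscaler; "input.png" is a placeholder):

    from PIL import Image
    from modules import images

    im = Image.open("input.png")

    stretched = images.resize_image(0, im, 768, 512)   # mode 0: plain resize to 768x512
    cropped   = images.resize_image(1, im, 768, 512)   # mode 1: fill the target, keep aspect ratio, crop the excess
    padded    = images.resize_image(2, im, 768, 512)   # mode 2: fit inside the target, fill the rest from the edges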
311
+ invalid_filename_chars = '<>:"/\\|?*\n'
312
+ invalid_filename_prefix = ' '
313
+ invalid_filename_postfix = ' .'
314
+ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
315
+ re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
316
+ re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
317
+ max_filename_part_length = 128
318
+
319
+
320
+ def sanitize_filename_part(text, replace_spaces=True):
321
+ if text is None:
322
+ return None
323
+
324
+ if replace_spaces:
325
+ text = text.replace(' ', '_')
326
+
327
+ text = text.translate({ord(x): '_' for x in invalid_filename_chars})
328
+ text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
329
+ text = text.rstrip(invalid_filename_postfix)
330
+ return text
331
+
332
+
333
+ class FilenameGenerator:
334
+ replacements = {
335
+ 'seed': lambda self: self.seed if self.seed is not None else '',
336
+ 'steps': lambda self: self.p and self.p.steps,
337
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
338
+ 'width': lambda self: self.image.width,
339
+ 'height': lambda self: self.image.height,
340
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
341
+ 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
342
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
343
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
344
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
345
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
346
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
347
+ 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
348
+ 'prompt': lambda self: sanitize_filename_part(self.prompt),
349
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
350
+ 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
351
+ 'prompt_words': lambda self: self.prompt_words(),
352
+ }
353
+ default_time_format = '%Y%m%d%H%M%S'
354
+
355
+ def __init__(self, p, seed, prompt, image):
356
+ self.p = p
357
+ self.seed = seed
358
+ self.prompt = prompt
359
+ self.image = image
360
+
361
+ def prompt_no_style(self):
362
+ if self.p is None or self.prompt is None:
363
+ return None
364
+
365
+ prompt_no_style = self.prompt
366
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
367
+ if len(style) > 0:
368
+ for part in style.split("{prompt}"):
369
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
370
+
371
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
372
+
373
+ return sanitize_filename_part(prompt_no_style, replace_spaces=False)
374
+
375
+ def prompt_words(self):
376
+ words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
377
+ if len(words) == 0:
378
+ words = ["empty"]
379
+ return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
380
+
381
+ def datetime(self, *args):
382
+ time_datetime = datetime.datetime.now()
383
+
384
+ time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
385
+ try:
386
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
387
+ except pytz.exceptions.UnknownTimeZoneError as _:
388
+ time_zone = None
389
+
390
+ time_zone_time = time_datetime.astimezone(time_zone)
391
+ try:
392
+ formatted_time = time_zone_time.strftime(time_format)
393
+ except (ValueError, TypeError) as _:
394
+ formatted_time = time_zone_time.strftime(self.default_time_format)
395
+
396
+ return sanitize_filename_part(formatted_time, replace_spaces=False)
397
+
398
+ def apply(self, x):
399
+ res = ''
400
+
401
+ for m in re_pattern.finditer(x):
402
+ text, pattern = m.groups()
403
+ res += text
404
+
405
+ if pattern is None:
406
+ continue
407
+
408
+ pattern_args = []
409
+ while True:
410
+ m = re_pattern_arg.match(pattern)
411
+ if m is None:
412
+ break
413
+
414
+ pattern, arg = m.groups()
415
+ pattern_args.insert(0, arg)
416
+
417
+ fun = self.replacements.get(pattern.lower())
418
+ if fun is not None:
419
+ try:
420
+ replacement = fun(self, *pattern_args)
421
+ except Exception:
422
+ replacement = None
423
+ print(f"Error adding [{pattern}] to filename", file=sys.stderr)
424
+ print(traceback.format_exc(), file=sys.stderr)
425
+
426
+ if replacement is not None:
427
+ res += str(replacement)
428
+ continue
429
+
430
+ res += f'[{pattern}]'
431
+
432
+ return res
433
+
434
+
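A sketch of the pattern expansion done by apply(); p and result_image stand in for a real processing object and output image, and the exact output depends on the configured options:

    from modules.images import FilenameGenerator

    namegen = FilenameGenerator(p, seed=12345, prompt="a castle on a hill", image=result_image)

    namegen.apply("[seed]-[prompt_words]")    # -> roughly "12345-a castle on a hill"
    namegen.apply("[datetime<%Y-%m-%d>]")     # [datetime] accepts an optional format and time zone argument
    namegen.apply("[unknown_token]")          # unknown tokens are left as the literal "[unknown_token]"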
435
+ def get_next_sequence_number(path, basename):
436
+ """
437
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
438
+
439
+ The sequence starts at 0.
440
+ """
441
+ result = -1
442
+ if basename != '':
443
+ basename = basename + "-"
444
+
445
+ prefix_length = len(basename)
446
+ for p in os.listdir(path):
447
+ if p.startswith(basename):
448
+ l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
449
+ try:
450
+ result = max(int(l[0]), result)
451
+ except ValueError:
452
+ pass
453
+
454
+ return result + 1
455
+
456
+
457
+ def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
458
+ """Save an image.
459
+
460
+ Args:
461
+ image (`PIL.Image`):
462
+ The image to be saved.
463
+ path (`str`):
464
+ The directory to save the image in. Note: the `save_to_dirs` option will cause the image to be saved into a subdirectory of this path.
465
+ basename (`str`):
466
+ The base filename which will be applied to `filename pattern`.
467
+ seed, prompt, short_filename,
468
+ extension (`str`):
469
+ Image file extension, default is `png`.
470
+ pnginfo_section_name (`str`):
471
+ Specify the name of the section which `info` will be saved in.
472
+ info (`str` or `PngImagePlugin.iTXt`):
473
+ PNG info chunks.
474
+ existing_info (`dict`):
475
+ Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
476
+ no_prompt:
477
+ If True, the image is not saved into a prompt-derived subdirectory even when `save_to_dirs` is enabled in the options.
478
+ p (`StableDiffusionProcessing`)
479
+ forced_filename (`str`):
480
+ If specified, `basename` and filename pattern will be ignored.
481
+ save_to_dirs (bool):
482
+ If true, the image will be saved into a subdirectory of `path`.
483
+
484
+ Returns: (fullfn, txt_fullfn)
485
+ fullfn (`str`):
486
+ The full path of the saved image.
487
+ txt_fullfn (`str` or None):
488
+ If a text file is saved for this image, this will be its full path. Otherwise None.
489
+ """
490
+ namegen = FilenameGenerator(p, seed, prompt, image)
491
+
492
+ if save_to_dirs is None:
493
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
494
+
495
+ if save_to_dirs:
496
+ dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
497
+ path = os.path.join(path, dirname)
498
+
499
+ os.makedirs(path, exist_ok=True)
500
+
501
+ if forced_filename is None:
502
+ if short_filename or seed is None:
503
+ file_decoration = ""
504
+ elif opts.save_to_dirs:
505
+ file_decoration = opts.samples_filename_pattern or "[seed]"
506
+ else:
507
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
508
+
509
+ add_number = opts.save_images_add_number or file_decoration == ''
510
+
511
+ if file_decoration != "" and add_number:
512
+ file_decoration = "-" + file_decoration
513
+
514
+ file_decoration = namegen.apply(file_decoration) + suffix
515
+
516
+ if add_number:
517
+ basecount = get_next_sequence_number(path, basename)
518
+ fullfn = None
519
+ for i in range(500):
520
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
521
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
522
+ if not os.path.exists(fullfn):
523
+ break
524
+ else:
525
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
526
+ else:
527
+ fullfn = os.path.join(path, f"{forced_filename}.{extension}")
528
+
529
+ pnginfo = existing_info or {}
530
+ if info is not None:
531
+ pnginfo[pnginfo_section_name] = info
532
+
533
+ params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
534
+ script_callbacks.before_image_saved_callback(params)
535
+
536
+ image = params.image
537
+ fullfn = params.filename
538
+ info = params.pnginfo.get(pnginfo_section_name, None)
539
+
540
+ def _atomically_save_image(image_to_save, filename_without_extension, extension):
541
+ # save image with .tmp extension to avoid race condition when another process detects new image in the directory
542
+ temp_file_path = filename_without_extension + ".tmp"
543
+ image_format = Image.registered_extensions()[extension]
544
+
545
+ if extension.lower() == '.png':
546
+ pnginfo_data = PngImagePlugin.PngInfo()
547
+ if opts.enable_pnginfo:
548
+ for k, v in params.pnginfo.items():
549
+ pnginfo_data.add_text(k, str(v))
550
+
551
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
552
+
553
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
554
+ if image_to_save.mode == 'RGBA':
555
+ image_to_save = image_to_save.convert("RGB")
556
+ elif image_to_save.mode == 'I;16':
557
+ image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")  # 0.0038910506 ≈ 255/65535: scale 16-bit pixel values down to the 8-bit range
558
+
559
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
560
+
561
+ if opts.enable_pnginfo and info is not None:
562
+ exif_bytes = piexif.dump({
563
+ "Exif": {
564
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
565
+ },
566
+ })
567
+
568
+ piexif.insert(exif_bytes, temp_file_path)
569
+ else:
570
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
571
+
572
+ # atomically rename the file with correct extension
573
+ os.replace(temp_file_path, filename_without_extension + extension)
574
+
575
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
576
+ _atomically_save_image(image, fullfn_without_extension, extension)
577
+
578
+ image.already_saved_as = fullfn
579
+
580
+ oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
581
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
582
+ ratio = image.width / image.height
583
+
584
+ if oversize and ratio > 1:
585
+ image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
586
+ elif oversize:
587
+ image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
588
+
589
+ try:
590
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
591
+ except Exception as e:
592
+ errors.display(e, "saving image as downscaled JPG")
593
+
594
+ if opts.save_txt and info is not None:
595
+ txt_fullfn = f"{fullfn_without_extension}.txt"
596
+ with open(txt_fullfn, "w", encoding="utf8") as file:
597
+ file.write(info + "\n")
598
+ else:
599
+ txt_fullfn = None
600
+
601
+ script_callbacks.image_saved_callback(params)
602
+
603
+ return fullfn, txt_fullfn
604
+
605
+
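A usage sketch for save_image (not part of the upload); result_image, infotext and p stand in for the outputs of a real generation:

    from modules import images
    from modules.shared import opts

    fullfn, txt_fullfn = images.save_image(
        image=result_image,                  # PIL.Image produced by a generation
        path=opts.outdir_txt2img_samples,
        basename="",
        seed=12345,
        prompt="a castle on a hill",
        info=infotext,                       # the generation parameters string
        p=p,
    )
    # fullfn is the saved file; txt_fullfn is the sidecar .txt path, or None if opts.save_txt is off.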
606
+ def read_info_from_image(image):
607
+ items = image.info or {}
608
+
609
+ geninfo = items.pop('parameters', None)
610
+
611
+ if "exif" in items:
612
+ exif = piexif.load(items["exif"])
613
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
614
+ try:
615
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
616
+ except ValueError:
617
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
618
+
619
+ if exif_comment:
620
+ items['exif comment'] = exif_comment
621
+ geninfo = exif_comment
622
+
623
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
624
+ 'loop', 'background', 'timestamp', 'duration']:
625
+ items.pop(field, None)
626
+
627
+ if items.get("Software", None) == "NovelAI":
628
+ try:
629
+ json_info = json.loads(items["Comment"])
630
+ sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
631
+
632
+ geninfo = f"""{items["Description"]}
633
+ Negative prompt: {json_info["uc"]}
634
+ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
635
+ except Exception:
636
+ print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
637
+ print(traceback.format_exc(), file=sys.stderr)
638
+
639
+ return geninfo, items
640
+
641
+
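The counterpart sketch for reading the parameters back out of a saved file and turning them into a dict again (fullfn refers to a file written by save_image as in the example above):

    from PIL import Image
    from modules import images
    from modules.generation_parameters_copypaste import parse_generation_parameters

    img = Image.open(fullfn)
    geninfo, extra_items = images.read_info_from_image(img)
    params = parse_generation_parameters(geninfo)   # {"Prompt": ..., "Steps": ..., "Size-1": ..., ...}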
642
+ def image_data(data):
643
+ try:
644
+ image = Image.open(io.BytesIO(data))
645
+ textinfo, _ = read_info_from_image(image)
646
+ return textinfo, None
647
+ except Exception:
648
+ pass
649
+
650
+ try:
651
+ text = data.decode('utf8')
652
+ assert len(text) < 10000
653
+ return text, None
654
+
655
+ except Exception:
656
+ pass
657
+
658
+ return '', None
659
+
660
+
661
+ def flatten(img, bgcolor):
662
+ """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
663
+
664
+ if img.mode == "RGBA":
665
+ background = Image.new('RGBA', img.size, bgcolor)
666
+ background.paste(img, mask=img)
667
+ img = background
668
+
669
+ return img.convert('RGB')
sd/stable-diffusion-webui/modules/img2img.py ADDED
@@ -0,0 +1,184 @@
1
+ import math
2
+ import os
3
+ import sys
4
+ import traceback
5
+
6
+ import numpy as np
7
+ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
8
+
9
+ from modules import devices, sd_samplers
10
+ from modules.generation_parameters_copypaste import create_override_settings_dict
11
+ from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
12
+ from modules.shared import opts, state
13
+ import modules.shared as shared
14
+ import modules.processing as processing
15
+ from modules.ui import plaintext_to_html
16
+ import modules.images as images
17
+ import modules.scripts
18
+
19
+
20
+ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
21
+ processing.fix_seed(p)
22
+
23
+ images = shared.listfiles(input_dir)
24
+
25
+ is_inpaint_batch = False
26
+ if inpaint_mask_dir:
27
+ inpaint_masks = shared.listfiles(inpaint_mask_dir)
28
+ is_inpaint_batch = len(inpaint_masks) > 0
29
+ if is_inpaint_batch:
30
+ print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
31
+
32
+ print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
33
+
34
+ save_normally = output_dir == ''
35
+
36
+ p.do_not_save_grid = True
37
+ p.do_not_save_samples = not save_normally
38
+
39
+ state.job_count = len(images) * p.n_iter
40
+
41
+ for i, image in enumerate(images):
42
+ state.job = f"{i+1} out of {len(images)}"
43
+ if state.skipped:
44
+ state.skipped = False
45
+
46
+ if state.interrupted:
47
+ break
48
+
49
+ img = Image.open(image)
50
+ # Use the EXIF orientation of photos taken by smartphones.
51
+ img = ImageOps.exif_transpose(img)
52
+ p.init_images = [img] * p.batch_size
53
+
54
+ if is_inpaint_batch:
55
+ # try to find corresponding mask for an image using simple filename matching
56
+ mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
57
+ # if not found use first one ("same mask for all images" use-case)
58
+ if mask_image_path not in inpaint_masks:
59
+ mask_image_path = inpaint_masks[0]
60
+ mask_image = Image.open(mask_image_path)
61
+ p.image_mask = mask_image
62
+
63
+ proc = modules.scripts.scripts_img2img.run(p, *args)
64
+ if proc is None:
65
+ proc = process_images(p)
66
+
67
+ for n, processed_image in enumerate(proc.images):
68
+ filename = os.path.basename(image)
69
+
70
+ if n > 0:
71
+ left, right = os.path.splitext(filename)
72
+ filename = f"{left}-{n}{right}"
73
+
74
+ if not save_normally:
75
+ os.makedirs(output_dir, exist_ok=True)
76
+ if processed_image.mode == 'RGBA':
77
+ processed_image = processed_image.convert("RGB")
78
+ processed_image.save(os.path.join(output_dir, filename))
79
+
80
+
81
+ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
82
+ override_settings = create_override_settings_dict(override_settings_texts)
83
+
84
+ is_batch = mode == 5
85
+
86
+ if mode == 0: # img2img
87
+ image = init_img.convert("RGB")
88
+ mask = None
89
+ elif mode == 1: # img2img sketch
90
+ image = sketch.convert("RGB")
91
+ mask = None
92
+ elif mode == 2: # inpaint
93
+ image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
94
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
95
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
96
+ image = image.convert("RGB")
97
+ elif mode == 3: # inpaint sketch
98
+ image = inpaint_color_sketch
99
+ orig = inpaint_color_sketch_orig or inpaint_color_sketch
100
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
101
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
102
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
103
+ blur = ImageFilter.GaussianBlur(mask_blur)
104
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
105
+ image = image.convert("RGB")
106
+ elif mode == 4: # inpaint upload mask
107
+ image = init_img_inpaint
108
+ mask = init_mask_inpaint
109
+ else:
110
+ image = None
111
+ mask = None
112
+
113
+ # Use the EXIF orientation of photos taken by smartphones.
114
+ if image is not None:
115
+ image = ImageOps.exif_transpose(image)
116
+
117
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
118
+
119
+ p = StableDiffusionProcessingImg2Img(
120
+ sd_model=shared.sd_model,
121
+ outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
122
+ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
123
+ prompt=prompt,
124
+ negative_prompt=negative_prompt,
125
+ styles=prompt_styles,
126
+ seed=seed,
127
+ subseed=subseed,
128
+ subseed_strength=subseed_strength,
129
+ seed_resize_from_h=seed_resize_from_h,
130
+ seed_resize_from_w=seed_resize_from_w,
131
+ seed_enable_extras=seed_enable_extras,
132
+ sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
133
+ batch_size=batch_size,
134
+ n_iter=n_iter,
135
+ steps=steps,
136
+ cfg_scale=cfg_scale,
137
+ width=width,
138
+ height=height,
139
+ restore_faces=restore_faces,
140
+ tiling=tiling,
141
+ init_images=[image],
142
+ mask=mask,
143
+ mask_blur=mask_blur,
144
+ inpainting_fill=inpainting_fill,
145
+ resize_mode=resize_mode,
146
+ denoising_strength=denoising_strength,
147
+ image_cfg_scale=image_cfg_scale,
148
+ inpaint_full_res=inpaint_full_res,
149
+ inpaint_full_res_padding=inpaint_full_res_padding,
150
+ inpainting_mask_invert=inpainting_mask_invert,
151
+ override_settings=override_settings,
152
+ )
153
+
154
+ p.scripts = modules.scripts.scripts_img2img  # img2img runs its own script collection (used via scripts_img2img.run below)
155
+ p.script_args = args
156
+
157
+ if shared.cmd_opts.enable_console_prompts:
158
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
159
+
160
+ p.extra_generation_params["Mask blur"] = mask_blur
161
+
162
+ if is_batch:
163
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
164
+
165
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
166
+
167
+ processed = Processed(p, [], p.seed, "")
168
+ else:
169
+ processed = modules.scripts.scripts_img2img.run(p, *args)
170
+ if processed is None:
171
+ processed = process_images(p)
172
+
173
+ p.close()
174
+
175
+ shared.total_tqdm.clear()
176
+
177
+ generation_info_js = processed.js()
178
+ if opts.samples_log_stdout:
179
+ print(generation_info_js)
180
+
181
+ if opts.do_not_show_images:
182
+ processed.images = []
183
+
184
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
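The mask used by the "inpaint sketch" mode (mode 3) above is derived by diffing the painted image against the original; the same step as a standalone sketch (file names are placeholders):

    import numpy as np
    from PIL import Image

    sketch = Image.open("sketch.png").convert("RGB")   # user-painted version
    orig = Image.open("orig.png").convert("RGB")       # untouched original

    # Any pixel where at least one channel differs from the original becomes part of the mask.
    changed = np.any(np.array(sketch) != np.array(orig), axis=-1)
    mask = Image.fromarray(changed.astype(np.uint8) * 255, "L")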
sd/stable-diffusion-webui/modules/import_hook.py ADDED
@@ -0,0 +1,5 @@
1
+ import sys
2
+
3
+ # this breaks any attempt to import xformers, which prevents the stable diffusion repo from trying to use it
4
+ if "--xformers" not in "".join(sys.argv):
5
+ sys.modules["xformers"] = None
sd/stable-diffusion-webui/modules/interrogate.py ADDED
@@ -0,0 +1,227 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+ from collections import namedtuple
5
+ from pathlib import Path
6
+ import re
7
+
8
+ import torch
9
+ import torch.hub
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode
13
+
14
+ import modules.shared as shared
15
+ from modules import devices, paths, shared, lowvram, modelloader, errors
16
+
17
+ blip_image_eval_size = 384
18
+ clip_model_name = 'ViT-L/14'
19
+
20
+ Category = namedtuple("Category", ["name", "topn", "items"])
21
+
22
+ re_topn = re.compile(r"\.top(\d+)\.")
23
+
24
+ def category_types():
25
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
26
+
27
+
28
+ def download_default_clip_interrogate_categories(content_dir):
29
+ print("Downloading CLIP categories...")
30
+
31
+ tmpdir = content_dir + "_tmp"
32
+ category_types = ["artists", "flavors", "mediums", "movements"]
33
+
34
+ try:
35
+ os.makedirs(tmpdir)
36
+ for category_type in category_types:
37
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
38
+ os.rename(tmpdir, content_dir)
39
+
40
+ except Exception as e:
41
+ errors.display(e, "downloading default CLIP interrogate categories")
42
+ finally:
43
+ if os.path.exists(tmpdir):
44
+ os.rmdir(tmpdir)  # os.remove cannot delete a directory; remove the leftover temp dir instead
45
+
46
+
47
+ class InterrogateModels:
48
+ blip_model = None
49
+ clip_model = None
50
+ clip_preprocess = None
51
+ dtype = None
52
+ running_on_cpu = None
53
+
54
+ def __init__(self, content_dir):
55
+ self.loaded_categories = None
56
+ self.skip_categories = []
57
+ self.content_dir = content_dir
58
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
59
+
60
+ def categories(self):
61
+ if not os.path.exists(self.content_dir):
62
+ download_default_clip_interrogate_categories(self.content_dir)
63
+
64
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
65
+ return self.loaded_categories
66
+
67
+ self.loaded_categories = []
68
+
69
+ if os.path.exists(self.content_dir):
70
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
71
+ category_types = []
72
+ for filename in Path(self.content_dir).glob('*.txt'):
73
+ category_types.append(filename.stem)
74
+ if filename.stem in self.skip_categories:
75
+ continue
76
+ m = re_topn.search(filename.stem)
77
+ topn = 1 if m is None else int(m.group(1))
78
+ with open(filename, "r", encoding="utf8") as file:
79
+ lines = [x.strip() for x in file.readlines()]
80
+
81
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
82
+
83
+ return self.loaded_categories
84
+
85
+ def create_fake_fairscale(self):
86
+ class FakeFairscale:
87
+ def checkpoint_wrapper(self):
88
+ pass
89
+
90
+ sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
91
+
92
+ def load_blip_model(self):
93
+ self.create_fake_fairscale()
94
+ import models.blip
95
+
96
+ files = modelloader.load_models(
97
+ model_path=os.path.join(paths.models_path, "BLIP"),
98
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
99
+ ext_filter=[".pth"],
100
+ download_name='model_base_caption_capfilt_large.pth',
101
+ )
102
+
103
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
104
+ blip_model.eval()
105
+
106
+ return blip_model
107
+
108
+ def load_clip_model(self):
109
+ import clip
110
+
111
+ if self.running_on_cpu:
112
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
113
+ else:
114
+ model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
115
+
116
+ model.eval()
117
+ model = model.to(devices.device_interrogate)
118
+
119
+ return model, preprocess
120
+
121
+ def load(self):
122
+ if self.blip_model is None:
123
+ self.blip_model = self.load_blip_model()
124
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
125
+ self.blip_model = self.blip_model.half()
126
+
127
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
128
+
129
+ if self.clip_model is None:
130
+ self.clip_model, self.clip_preprocess = self.load_clip_model()
131
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
132
+ self.clip_model = self.clip_model.half()
133
+
134
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
135
+
136
+ self.dtype = next(self.clip_model.parameters()).dtype
137
+
138
+ def send_clip_to_ram(self):
139
+ if not shared.opts.interrogate_keep_models_in_memory:
140
+ if self.clip_model is not None:
141
+ self.clip_model = self.clip_model.to(devices.cpu)
142
+
143
+ def send_blip_to_ram(self):
144
+ if not shared.opts.interrogate_keep_models_in_memory:
145
+ if self.blip_model is not None:
146
+ self.blip_model = self.blip_model.to(devices.cpu)
147
+
148
+ def unload(self):
149
+ self.send_clip_to_ram()
150
+ self.send_blip_to_ram()
151
+
152
+ devices.torch_gc()
153
+
154
+ def rank(self, image_features, text_array, top_count=1):
155
+ import clip
156
+
157
+ devices.torch_gc()
158
+
159
+ if shared.opts.interrogate_clip_dict_limit != 0:
160
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
161
+
162
+ top_count = min(top_count, len(text_array))
163
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
164
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
165
+ text_features /= text_features.norm(dim=-1, keepdim=True)
166
+
167
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
168
+ for i in range(image_features.shape[0]):
169
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
170
+ similarity /= image_features.shape[0]
171
+
172
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
173
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
174
+
175
+ def generate_caption(self, pil_image):
176
+ gpu_image = transforms.Compose([
177
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
178
+ transforms.ToTensor(),
179
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
180
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
181
+
182
+ with torch.no_grad():
183
+ caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
184
+
185
+ return caption[0]
186
+
187
+ def interrogate(self, pil_image):
188
+ res = ""
189
+ shared.state.begin()
190
+ shared.state.job = 'interrogate'
191
+ try:
192
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
193
+ lowvram.send_everything_to_cpu()
194
+ devices.torch_gc()
195
+
196
+ self.load()
197
+
198
+ caption = self.generate_caption(pil_image)
199
+ self.send_blip_to_ram()
200
+ devices.torch_gc()
201
+
202
+ res = caption
203
+
204
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
205
+
206
+ with torch.no_grad(), devices.autocast():
207
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
208
+
209
+ image_features /= image_features.norm(dim=-1, keepdim=True)
210
+
211
+ for name, topn, items in self.categories():
212
+ matches = self.rank(image_features, items, top_count=topn)
213
+ for match, score in matches:
214
+ if shared.opts.interrogate_return_ranks:
215
+ res += f", ({match}:{score/100:.3f})"
216
+ else:
217
+ res += ", " + match
218
+
219
+ except Exception:
220
+ print("Error interrogating", file=sys.stderr)
221
+ print(traceback.format_exc(), file=sys.stderr)
222
+ res += "<error>"
223
+
224
+ self.unload()
225
+ shared.state.end()
226
+
227
+ return res
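
The rank() method above reduces to a cosine-similarity scoring of the image against every candidate text: both feature sets are L2-normalized, their dot products are scaled by 100 and softmaxed, and the top entries are kept. A minimal, self-contained sketch of that scoring step, with random tensors standing in for the CLIP encoder outputs (shapes are illustrative):

import torch

# stand-ins for CLIP features; real ones come from clip_model.encode_image / encode_text
image_features = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)
text_features = torch.nn.functional.normalize(torch.randn(8, 512), dim=-1)

# cosine similarity scaled by 100 and softmaxed over the candidate texts, as in rank()
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = similarity.topk(3, dim=-1)  # best 3 candidates for the image
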
sd/stable-diffusion-webui/modules/localization.py ADDED
@@ -0,0 +1,37 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ import traceback
5
+
6
+
7
+ localizations = {}
8
+
9
+
10
+ def list_localizations(dirname):
11
+ localizations.clear()
12
+
13
+ for file in os.listdir(dirname):
14
+ fn, ext = os.path.splitext(file)
15
+ if ext.lower() != ".json":
16
+ continue
17
+
18
+ localizations[fn] = os.path.join(dirname, file)
19
+
20
+ from modules import scripts
21
+ for file in scripts.list_scripts("localizations", ".json"):
22
+ fn, ext = os.path.splitext(file.filename)
23
+ localizations[fn] = file.path
24
+
25
+
26
+ def localization_js(current_localization_name):
27
+ fn = localizations.get(current_localization_name, None)
28
+ data = {}
29
+ if fn is not None:
30
+ try:
31
+ with open(fn, "r", encoding="utf8") as file:
32
+ data = json.load(file)
33
+ except Exception:
34
+ print(f"Error loading localization from {fn}:", file=sys.stderr)
35
+ print(traceback.format_exc(), file=sys.stderr)
36
+
37
+ return f"var localization = {json.dumps(data)}\n"
sd/stable-diffusion-webui/modules/lowvram.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ from modules import devices
3
+
4
+ module_in_gpu = None
5
+ cpu = torch.device("cpu")
6
+
7
+
8
+ def send_everything_to_cpu():
9
+ global module_in_gpu
10
+
11
+ if module_in_gpu is not None:
12
+ module_in_gpu.to(cpu)
13
+
14
+ module_in_gpu = None
15
+
16
+
17
+ def setup_for_low_vram(sd_model, use_medvram):
18
+ parents = {}
19
+
20
+ def send_me_to_gpu(module, _):
21
+ """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
22
+ we add this as forward_pre_hook to a lot of modules and this way all but one of them will
23
+ be in CPU
24
+ """
25
+ global module_in_gpu
26
+
27
+ module = parents.get(module, module)
28
+
29
+ if module_in_gpu == module:
30
+ return
31
+
32
+ if module_in_gpu is not None:
33
+ module_in_gpu.to(cpu)
34
+
35
+ module.to(devices.device)
36
+ module_in_gpu = module
37
+
38
+ # see below for register_forward_pre_hook;
39
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
40
+ # useless here, and we just replace those methods
41
+
42
+ first_stage_model = sd_model.first_stage_model
43
+ first_stage_model_encode = sd_model.first_stage_model.encode
44
+ first_stage_model_decode = sd_model.first_stage_model.decode
45
+
46
+ def first_stage_model_encode_wrap(x):
47
+ send_me_to_gpu(first_stage_model, None)
48
+ return first_stage_model_encode(x)
49
+
50
+ def first_stage_model_decode_wrap(z):
51
+ send_me_to_gpu(first_stage_model, None)
52
+ return first_stage_model_decode(z)
53
+
54
+ # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in the model field
55
+ if hasattr(sd_model.cond_stage_model, 'model'):
56
+ sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
57
+
58
+ # remove the four big modules (cond, first_stage, depth if applicable, and unet) from the model and then
59
+ # send the model to GPU. Then put the modules back; they will remain on the CPU.
60
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
61
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
62
+ sd_model.to(devices.device)
63
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
64
+
65
+ # register hooks for the first three models
66
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
67
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
68
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
69
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
70
+ if sd_model.depth_model:
71
+ sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
72
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
73
+
74
+ if hasattr(sd_model.cond_stage_model, 'model'):
75
+ sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
76
+ del sd_model.cond_stage_model.transformer
77
+
78
+ if use_medvram:
79
+ sd_model.model.register_forward_pre_hook(send_me_to_gpu)
80
+ else:
81
+ diff_model = sd_model.model.diffusion_model
82
+
83
+ # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
84
+ # so that only one of them is in GPU at a time
85
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
86
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
87
+ sd_model.model.to(devices.device)
88
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
89
+
90
+ # install hooks for bits of third model
91
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
92
+ for block in diff_model.input_blocks:
93
+ block.register_forward_pre_hook(send_me_to_gpu)
94
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
95
+ for block in diff_model.output_blocks:
96
+ block.register_forward_pre_hook(send_me_to_gpu)
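
The core idea of setup_for_low_vram above is that every large submodule gets a forward_pre_hook which moves it to the GPU just before it runs and evicts whichever module was there previously, so at most one of them occupies VRAM at a time. A minimal sketch of that pattern with toy modules (not the actual webui wiring):

import torch
import torch.nn as nn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cpu = torch.device("cpu")
module_in_gpu = None

def send_me_to_gpu(module, _inputs):
    # move the module about to run onto the GPU; evict whichever module was there before
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module.to(device)
    module_in_gpu = module

blocks = [nn.Linear(16, 16) for _ in range(3)]
for block in blocks:
    block.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 16)
for block in blocks:
    x = block(x.to(device))  # only the currently running block resides on the GPU
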
sd/stable-diffusion-webui/modules/mac_specific.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ from modules import paths
3
+ from modules.sd_hijack_utils import CondFunc
4
+ from packaging import version
5
+
6
+
7
+ # has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
8
+ # check with `getattr` and try a small tensor operation on the device for compatibility
9
+ def check_for_mps() -> bool:
10
+ if not getattr(torch, 'has_mps', False):
11
+ return False
12
+ try:
13
+ torch.zeros(1).to(torch.device("mps"))
14
+ return True
15
+ except Exception:
16
+ return False
17
+ has_mps = check_for_mps()
18
+
19
+
20
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
21
+ def cumsum_fix(input, cumsum_func, *args, **kwargs):
22
+ if input.device.type == 'mps':
23
+ output_dtype = kwargs.get('dtype', input.dtype)
24
+ if output_dtype == torch.int64:
25
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
26
+ elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
27
+ return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
28
+ return cumsum_func(input, *args, **kwargs)
29
+
30
+
31
+ if has_mps:
32
+ # MPS fix for randn in torchsde
33
+ CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
34
+
35
+ if version.parse(torch.__version__) < version.parse("1.13"):
36
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
37
+
38
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
39
+ CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
40
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
41
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
42
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
43
+ lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
44
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
45
+ CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
46
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
47
+ cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
48
+ cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
49
+ cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
50
+ CondFunc('torch.cumsum', cumsum_fix_func, None)
51
+ CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
52
+ CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
53
+
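
CondFunc (imported from modules.sd_hijack_utils) is used above as a conditional monkey-patch: replace a dotted function path with a wrapper, but only apply the override when a predicate on the call's arguments is true. A rough, hypothetical sketch of that idea for a plain attribute, not the actual sd_hijack_utils implementation:

import torch

def cond_patch(owner, attr, sub_func, cond_func):
    """Replace owner.attr so sub_func handles calls matching cond_func; otherwise the original runs."""
    orig_func = getattr(owner, attr)

    def wrapper(*args, **kwargs):
        if cond_func is None or cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    setattr(owner, attr, wrapper)

# example: force a .clone() after torch.narrow, mirroring the workaround above
cond_patch(torch, "narrow", lambda orig, *a, **kw: orig(*a, **kw).clone(), None)
t = torch.arange(6).reshape(2, 3)
print(torch.narrow(t, 0, 0, 1))  # goes through the wrapper and returns a cloned slice
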
sd/stable-diffusion-webui/modules/masking.py ADDED
@@ -0,0 +1,99 @@
1
+ from PIL import Image, ImageFilter, ImageOps
2
+
3
+
4
+ def get_crop_region(mask, pad=0):
5
+ """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6
+ For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
7
+
8
+ h, w = mask.shape
9
+
10
+ crop_left = 0
11
+ for i in range(w):
12
+ if not (mask[:, i] == 0).all():
13
+ break
14
+ crop_left += 1
15
+
16
+ crop_right = 0
17
+ for i in reversed(range(w)):
18
+ if not (mask[:, i] == 0).all():
19
+ break
20
+ crop_right += 1
21
+
22
+ crop_top = 0
23
+ for i in range(h):
24
+ if not (mask[i] == 0).all():
25
+ break
26
+ crop_top += 1
27
+
28
+ crop_bottom = 0
29
+ for i in reversed(range(h)):
30
+ if not (mask[i] == 0).all():
31
+ break
32
+ crop_bottom += 1
33
+
34
+ return (
35
+ int(max(crop_left-pad, 0)),
36
+ int(max(crop_top-pad, 0)),
37
+ int(min(w - crop_right + pad, w)),
38
+ int(min(h - crop_bottom + pad, h))
39
+ )
40
+
41
+
42
+ def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
43
+ """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
44
+ for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
45
+
46
+ x1, y1, x2, y2 = crop_region
47
+
48
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
49
+ ratio_processing = processing_width / processing_height
50
+
51
+ if ratio_crop_region > ratio_processing:
52
+ desired_height = (x2 - x1) / ratio_processing
53
+ desired_height_diff = int(desired_height - (y2-y1))
54
+ y1 -= desired_height_diff//2
55
+ y2 += desired_height_diff - desired_height_diff//2
56
+ if y2 >= image_height:
57
+ diff = y2 - image_height
58
+ y2 -= diff
59
+ y1 -= diff
60
+ if y1 < 0:
61
+ y2 -= y1
62
+ y1 -= y1
63
+ if y2 >= image_height:
64
+ y2 = image_height
65
+ else:
66
+ desired_width = (y2 - y1) * ratio_processing
67
+ desired_width_diff = int(desired_width - (x2-x1))
68
+ x1 -= desired_width_diff//2
69
+ x2 += desired_width_diff - desired_width_diff//2
70
+ if x2 >= image_width:
71
+ diff = x2 - image_width
72
+ x2 -= diff
73
+ x1 -= diff
74
+ if x1 < 0:
75
+ x2 -= x1
76
+ x1 -= x1
77
+ if x2 >= image_width:
78
+ x2 = image_width
79
+
80
+ return x1, y1, x2, y2
81
+
82
+
83
+ def fill(image, mask):
84
+ """fills masked regions with colors from image using blur. Not extremely effective."""
85
+
86
+ image_mod = Image.new('RGBA', (image.width, image.height))
87
+
88
+ image_masked = Image.new('RGBa', (image.width, image.height))
89
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
90
+
91
+ image_masked = image_masked.convert('RGBa')
92
+
93
+ for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
94
+ blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
95
+ for _ in range(repeats):
96
+ image_mod.alpha_composite(blurred)
97
+
98
+ return image_mod.convert("RGB")
99
+
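
A typical inpaint-at-full-resolution flow chains the two helpers above: first find the bounding box of the painted mask, then grow it to the aspect ratio of the processing resolution. A small usage sketch with a synthetic numpy mask (values are illustrative):

import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[0:200, 300:500] = 255  # user painted the top-right area

x1, y1, x2, y2 = get_crop_region(mask, pad=32)
x1, y1, x2, y2 = expand_crop_region((x1, y1, x2, y2), 512, 512, mask.shape[1], mask.shape[0])
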
sd/stable-diffusion-webui/modules/memmon.py ADDED
@@ -0,0 +1,88 @@
1
+ import threading
2
+ import time
3
+ from collections import defaultdict
4
+
5
+ import torch
6
+
7
+
8
+ class MemUsageMonitor(threading.Thread):
9
+ run_flag = None
10
+ device = None
11
+ disabled = False
12
+ opts = None
13
+ data = None
14
+
15
+ def __init__(self, name, device, opts):
16
+ threading.Thread.__init__(self)
17
+ self.name = name
18
+ self.device = device
19
+ self.opts = opts
20
+
21
+ self.daemon = True
22
+ self.run_flag = threading.Event()
23
+ self.data = defaultdict(int)
24
+
25
+ try:
26
+ torch.cuda.mem_get_info()
27
+ torch.cuda.memory_stats(self.device)
28
+ except Exception as e: # AMD or whatever
29
+ print(f"Warning: caught exception '{e}', memory monitor disabled")
30
+ self.disabled = True
31
+
32
+ def run(self):
33
+ if self.disabled:
34
+ return
35
+
36
+ while True:
37
+ self.run_flag.wait()
38
+
39
+ torch.cuda.reset_peak_memory_stats()
40
+ self.data.clear()
41
+
42
+ if self.opts.memmon_poll_rate <= 0:
43
+ self.run_flag.clear()
44
+ continue
45
+
46
+ self.data["min_free"] = torch.cuda.mem_get_info()[0]
47
+
48
+ while self.run_flag.is_set():
49
+ free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
50
+ self.data["min_free"] = min(self.data["min_free"], free)
51
+
52
+ time.sleep(1 / self.opts.memmon_poll_rate)
53
+
54
+ def dump_debug(self):
55
+ print(self, 'recorded data:')
56
+ for k, v in self.read().items():
57
+ print(k, -(v // -(1024 ** 2)))
58
+
59
+ print(self, 'raw torch memory stats:')
60
+ tm = torch.cuda.memory_stats(self.device)
61
+ for k, v in tm.items():
62
+ if 'bytes' not in k:
63
+ continue
64
+ print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
65
+
66
+ print(torch.cuda.memory_summary())
67
+
68
+ def monitor(self):
69
+ self.run_flag.set()
70
+
71
+ def read(self):
72
+ if not self.disabled:
73
+ free, total = torch.cuda.mem_get_info()
74
+ self.data["free"] = free
75
+ self.data["total"] = total
76
+
77
+ torch_stats = torch.cuda.memory_stats(self.device)
78
+ self.data["active"] = torch_stats["active.all.current"]
79
+ self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
80
+ self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
81
+ self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
82
+ self.data["system_peak"] = total - self.data["min_free"]
83
+
84
+ return self.data
85
+
86
+ def stop(self):
87
+ self.run_flag.clear()
88
+ return self.read()
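
MemUsageMonitor is started once and then toggled around each generation job; a hedged usage sketch (the opts object only needs a memmon_poll_rate attribute, stubbed here):

from types import SimpleNamespace
import torch

opts = SimpleNamespace(memmon_poll_rate=8)  # stand-in for shared.opts
mem_mon = MemUsageMonitor("MemMon", torch.device("cuda"), opts)
mem_mon.start()        # the thread waits until monitor() sets the run flag

mem_mon.monitor()      # begin sampling before a generation job
# ... run the job ...
stats = mem_mon.stop() # returns the collected data dict (values in bytes)
print(stats.get("active_peak", 0) // (1024 ** 2), "MB active peak")
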
sd/stable-diffusion-webui/modules/modelloader.py ADDED
@@ -0,0 +1,172 @@
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import importlib
5
+ from urllib.parse import urlparse
6
+
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from modules import shared
9
+ from modules.upscaler import Upscaler
10
+ from modules.paths import script_path, models_path
11
+
12
+
13
+ def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
14
+ """
15
+ A one-and-done loader that tries to find the desired models in the specified directories.
16
+
17
+ @param download_name: Specify to download from model_url immediately.
18
+ @param model_url: If no other models are found, this will be downloaded on upscale.
19
+ @param model_path: The location to store/find models in.
20
+ @param command_path: A command-line argument to search for models in first.
21
+ @param ext_filter: An optional list of filename extensions to filter by
22
+ @return: A list of paths containing the desired model(s)
23
+ """
24
+ output = []
25
+
26
+ if ext_filter is None:
27
+ ext_filter = []
28
+
29
+ try:
30
+ places = []
31
+
32
+ if command_path is not None and command_path != model_path:
33
+ pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
34
+ if os.path.exists(pretrained_path):
35
+ print(f"Appending path: {pretrained_path}")
36
+ places.append(pretrained_path)
37
+ elif os.path.exists(command_path):
38
+ places.append(command_path)
39
+
40
+ places.append(model_path)
41
+
42
+ for place in places:
43
+ if os.path.exists(place):
44
+ for file in glob.iglob(place + '**/**', recursive=True):
45
+ full_path = file
46
+ if os.path.isdir(full_path):
47
+ continue
48
+ if os.path.islink(full_path) and not os.path.exists(full_path):
49
+ print(f"Skipping broken symlink: {full_path}")
50
+ continue
51
+ if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
52
+ continue
53
+ if len(ext_filter) != 0:
54
+ model_name, extension = os.path.splitext(file)
55
+ if extension not in ext_filter:
56
+ continue
57
+ if file not in output:
58
+ output.append(full_path)
59
+
60
+ if model_url is not None and len(output) == 0:
61
+ if download_name is not None:
62
+ dl = load_file_from_url(model_url, model_path, True, download_name)
63
+ output.append(dl)
64
+ else:
65
+ output.append(model_url)
66
+
67
+ except Exception:
68
+ pass
69
+
70
+ return output
71
+
72
+
73
+ def friendly_name(file: str):
74
+ if "http" in file:
75
+ file = urlparse(file).path
76
+
77
+ file = os.path.basename(file)
78
+ model_name, extension = os.path.splitext(file)
79
+ return model_name
80
+
81
+
82
+ def cleanup_models():
83
+ # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
84
+ # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
85
+ # somehow auto-register and just do these things...
86
+ root_path = script_path
87
+ src_path = models_path
88
+ dest_path = os.path.join(models_path, "Stable-diffusion")
89
+ move_files(src_path, dest_path, ".ckpt")
90
+ move_files(src_path, dest_path, ".safetensors")
91
+ src_path = os.path.join(root_path, "ESRGAN")
92
+ dest_path = os.path.join(models_path, "ESRGAN")
93
+ move_files(src_path, dest_path)
94
+ src_path = os.path.join(models_path, "BSRGAN")
95
+ dest_path = os.path.join(models_path, "ESRGAN")
96
+ move_files(src_path, dest_path, ".pth")
97
+ src_path = os.path.join(root_path, "gfpgan")
98
+ dest_path = os.path.join(models_path, "GFPGAN")
99
+ move_files(src_path, dest_path)
100
+ src_path = os.path.join(root_path, "SwinIR")
101
+ dest_path = os.path.join(models_path, "SwinIR")
102
+ move_files(src_path, dest_path)
103
+ src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
104
+ dest_path = os.path.join(models_path, "LDSR")
105
+ move_files(src_path, dest_path)
106
+
107
+
108
+ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
109
+ try:
110
+ if not os.path.exists(dest_path):
111
+ os.makedirs(dest_path)
112
+ if os.path.exists(src_path):
113
+ for file in os.listdir(src_path):
114
+ fullpath = os.path.join(src_path, file)
115
+ if os.path.isfile(fullpath):
116
+ if ext_filter is not None:
117
+ if ext_filter not in file:
118
+ continue
119
+ print(f"Moving {file} from {src_path} to {dest_path}.")
120
+ try:
121
+ shutil.move(fullpath, dest_path)
122
+ except:
123
+ pass
124
+ if len(os.listdir(src_path)) == 0:
125
+ print(f"Removing empty folder: {src_path}")
126
+ shutil.rmtree(src_path, True)
127
+ except:
128
+ pass
129
+
130
+
131
+ builtin_upscaler_classes = []
132
+ forbidden_upscaler_classes = set()
133
+
134
+
135
+ def list_builtin_upscalers():
136
+ load_upscalers()
137
+
138
+ builtin_upscaler_classes.clear()
139
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
140
+
141
+
142
+ def forbid_loaded_nonbuiltin_upscalers():
143
+ for cls in Upscaler.__subclasses__():
144
+ if cls not in builtin_upscaler_classes:
145
+ forbidden_upscaler_classes.add(cls)
146
+
147
+
148
+ def load_upscalers():
149
+ # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
150
+ # so we'll try to import any _model.py files before looking in __subclasses__
151
+ modules_dir = os.path.join(shared.script_path, "modules")
152
+ for file in os.listdir(modules_dir):
153
+ if "_model.py" in file:
154
+ model_name = file.replace("_model.py", "")
155
+ full_model = f"modules.{model_name}_model"
156
+ try:
157
+ importlib.import_module(full_model)
158
+ except:
159
+ pass
160
+
161
+ datas = []
162
+ commandline_options = vars(shared.cmd_opts)
163
+ for cls in Upscaler.__subclasses__():
164
+ if cls in forbidden_upscaler_classes:
165
+ continue
166
+
167
+ name = cls.__name__
168
+ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
169
+ scaler = cls(commandline_options.get(cmd_name, None))
170
+ datas += scaler.scalers
171
+
172
+ shared.sd_upscalers = datas
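
As a usage sketch, an upscaler typically asks load_models() for every weight file with a given extension under its model directory, falling back to a download URL when nothing is found (the directory below is a placeholder):

model_paths = load_models(
    model_path="models/ESRGAN",   # placeholder directory
    model_url=None,               # or a URL to download when no local model exists
    command_path=None,
    ext_filter=[".pt", ".pth"],
    download_name=None,
)
for path in model_paths:
    print("found model:", friendly_name(path), "at", path)
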
sd/stable-diffusion-webui/modules/ngrok.py ADDED
@@ -0,0 +1,26 @@
1
+ from pyngrok import ngrok, conf, exception
2
+
3
+ def connect(token, port, region):
4
+ account = None
5
+ if token is None:
6
+ token = 'None'
7
+ else:
8
+ if ':' in token:
9
+ # token = authtoken:username:password
10
+ account = token.split(':')[1] + ':' + token.split(':')[-1]
11
+ token = token.split(':')[0]
12
+
13
+ config = conf.PyngrokConfig(
14
+ auth_token=token, region=region
15
+ )
16
+ try:
17
+ if account is None:
18
+ public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
19
+ else:
20
+ public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
21
+ except exception.PyngrokNgrokError:
22
+ print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
23
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
24
+ else:
25
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
26
+ 'You can use this link after the launch is complete.')
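
The token handling above accepts either a bare authtoken or the colon-separated form authtoken:username:password, in which case the username/password pair is forwarded as basic auth. For example (placeholder values):

token = "2abcDEFtoken:alice:hunter2"                         # hypothetical --ngrok value
authtoken = token.split(':')[0]                              # "2abcDEFtoken"
account = token.split(':')[1] + ':' + token.split(':')[-1]   # "alice:hunter2"
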
sd/stable-diffusion-webui/modules/paths.py ADDED
@@ -0,0 +1,62 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import modules.safe
5
+
6
+ script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
7
+
8
+ # Parse the --data-dir flag first so we can use it as a base for our other argument default values
9
+ parser = argparse.ArgumentParser(add_help=False)
10
+ parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
11
+ cmd_opts_pre = parser.parse_known_args()[0]
12
+ data_path = cmd_opts_pre.data_dir
13
+ models_path = os.path.join(data_path, "models")
14
+
15
+ # data_path = cmd_opts_pre.data
16
+ sys.path.insert(0, script_path)
17
+
18
+ # search for the stable diffusion directory in the following places
19
+ sd_path = None
20
+ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
21
+ for possible_sd_path in possible_sd_paths:
22
+ if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
23
+ sd_path = os.path.abspath(possible_sd_path)
24
+ break
25
+
26
+ assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
27
+
28
+ path_dirs = [
29
+ (sd_path, 'ldm', 'Stable Diffusion', []),
30
+ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
31
+ (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
32
+ (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
33
+ (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
34
+ ]
35
+
36
+ paths = {}
37
+
38
+ for d, must_exist, what, options in path_dirs:
39
+ must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
40
+ if not os.path.exists(must_exist_path):
41
+ print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
42
+ else:
43
+ d = os.path.abspath(d)
44
+ if "atstart" in options:
45
+ sys.path.insert(0, d)
46
+ else:
47
+ sys.path.append(d)
48
+ paths[what] = d
49
+
50
+
51
+ class Prioritize:
52
+ def __init__(self, name):
53
+ self.name = name
54
+ self.path = None
55
+
56
+ def __enter__(self):
57
+ self.path = sys.path.copy()
58
+ sys.path = [paths[self.name]] + sys.path
59
+
60
+ def __exit__(self, exc_type, exc_val, exc_tb):
61
+ sys.path = self.path
62
+ self.path = None
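
Prioritize is a small context manager that temporarily puts one of the repositories registered above at the front of sys.path, so an import resolves against that repository first; a hedged usage sketch (the imported name is illustrative):

with Prioritize("BLIP"):
    from models.blip import blip_decoder  # illustrative: resolved from the BLIP repo first
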
sd/stable-diffusion-webui/modules/postprocessing.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+
3
+ from PIL import Image
4
+
5
+ from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
6
+ from modules.shared import opts
7
+
8
+
9
+ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
10
+ devices.torch_gc()
11
+
12
+ shared.state.begin()
13
+ shared.state.job = 'extras'
14
+
15
+ image_data = []
16
+ image_names = []
17
+ outputs = []
18
+
19
+ if extras_mode == 1:
20
+ for img in image_folder:
21
+ image = Image.open(img)
22
+ image_data.append(image)
23
+ image_names.append(os.path.splitext(img.orig_name)[0])
24
+ elif extras_mode == 2:
25
+ assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
26
+ assert input_dir, 'input directory not selected'
27
+
28
+ image_list = shared.listfiles(input_dir)
29
+ for filename in image_list:
30
+ try:
31
+ image = Image.open(filename)
32
+ except Exception:
33
+ continue
34
+ image_data.append(image)
35
+ image_names.append(filename)
36
+ else:
37
+ assert image, 'image not selected'
38
+
39
+ image_data.append(image)
40
+ image_names.append(None)
41
+
42
+ if extras_mode == 2 and output_dir != '':
43
+ outpath = output_dir
44
+ else:
45
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
46
+
47
+ infotext = ''
48
+
49
+ for image, name in zip(image_data, image_names):
50
+ shared.state.textinfo = name
51
+
52
+ existing_pnginfo = image.info or {}
53
+
54
+ pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
55
+
56
+ scripts.scripts_postproc.run(pp, args)
57
+
58
+ if opts.use_original_name_batch and name is not None:
59
+ basename = os.path.splitext(os.path.basename(name))[0]
60
+ else:
61
+ basename = ''
62
+
63
+ infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
64
+
65
+ if opts.enable_pnginfo:
66
+ pp.image.info = existing_pnginfo
67
+ pp.image.info["postprocessing"] = infotext
68
+
69
+ if save_output:
70
+ images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
71
+
72
+ if extras_mode != 2 or show_extras_results:
73
+ outputs.append(pp.image)
74
+
75
+ devices.torch_gc()
76
+
77
+ return outputs, ui_common.plaintext_to_html(infotext), ''
78
+
79
+
80
+ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
81
+ """old handler for API"""
82
+
83
+ args = scripts.scripts_postproc.create_args_for_run({
84
+ "Upscale": {
85
+ "upscale_mode": resize_mode,
86
+ "upscale_by": upscaling_resize,
87
+ "upscale_to_width": upscaling_resize_w,
88
+ "upscale_to_height": upscaling_resize_h,
89
+ "upscale_crop": upscaling_crop,
90
+ "upscaler_1_name": extras_upscaler_1,
91
+ "upscaler_2_name": extras_upscaler_2,
92
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
93
+ },
94
+ "GFPGAN": {
95
+ "gfpgan_visibility": gfpgan_visibility,
96
+ },
97
+ "CodeFormer": {
98
+ "codeformer_visibility": codeformer_visibility,
99
+ "codeformer_weight": codeformer_weight,
100
+ },
101
+ })
102
+
103
+ return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
sd/stable-diffusion-webui/modules/processing.py ADDED
@@ -0,0 +1,1056 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import sys
5
+ import warnings
6
+
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image, ImageFilter, ImageOps
10
+ import random
11
+ import cv2
12
+ from skimage import exposure
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ import modules.sd_hijack
16
+ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
17
+ from modules.sd_hijack import model_hijack
18
+ from modules.shared import opts, cmd_opts, state
19
+ import modules.shared as shared
20
+ import modules.paths as paths
21
+ import modules.face_restoration
22
+ import modules.images as images
23
+ import modules.styles
24
+ import modules.sd_models as sd_models
25
+ import modules.sd_vae as sd_vae
26
+ import logging
27
+ from ldm.data.util import AddMiDaS
28
+ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
29
+
30
+ from einops import repeat, rearrange
31
+ from blendmodes.blend import blendLayers, BlendType
32
+
33
+ # some of those options should not be changed at all because they would break the model, so I removed them from options.
34
+ opt_C = 4
35
+ opt_f = 8
36
+
37
+
38
+ def setup_color_correction(image):
39
+ logging.info("Calibrating color correction.")
40
+ correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
41
+ return correction_target
42
+
43
+
44
+ def apply_color_correction(correction, original_image):
45
+ logging.info("Applying color correction.")
46
+ image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
47
+ cv2.cvtColor(
48
+ np.asarray(original_image),
49
+ cv2.COLOR_RGB2LAB
50
+ ),
51
+ correction,
52
+ channel_axis=2
53
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
54
+
55
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
56
+
57
+ return image
58
+
59
+
60
+ def apply_overlay(image, paste_loc, index, overlays):
61
+ if overlays is None or index >= len(overlays):
62
+ return image
63
+
64
+ overlay = overlays[index]
65
+
66
+ if paste_loc is not None:
67
+ x, y, w, h = paste_loc
68
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
69
+ image = images.resize_image(1, image, w, h)
70
+ base_image.paste(image, (x, y))
71
+ image = base_image
72
+
73
+ image = image.convert('RGBA')
74
+ image.alpha_composite(overlay)
75
+ image = image.convert('RGB')
76
+
77
+ return image
78
+
79
+
80
+ def txt2img_image_conditioning(sd_model, x, width, height):
81
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
82
+ # Dummy zero conditioning if we're not using inpainting model.
83
+ # Still takes up a bit of memory, but no encoder call.
84
+ # Pretty sure we can just make this a 1x1 image since it's not going to be used for anything besides its batch size.
85
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
86
+
87
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
88
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
89
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
90
+
91
+ # Add the fake full 1s mask to the first dimension.
92
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
93
+ image_conditioning = image_conditioning.to(x.dtype)
94
+
95
+ return image_conditioning
96
+
97
+
98
+ class StableDiffusionProcessing:
99
+ """
100
+ The first set of parameters: sd_model -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
101
+ """
102
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
103
+ if sampler_index is not None:
104
+ print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
105
+
106
+ self.outpath_samples: str = outpath_samples
107
+ self.outpath_grids: str = outpath_grids
108
+ self.prompt: str = prompt
109
+ self.prompt_for_display: str = None
110
+ self.negative_prompt: str = (negative_prompt or "")
111
+ self.styles: list = styles or []
112
+ self.seed: int = seed
113
+ self.subseed: int = subseed
114
+ self.subseed_strength: float = subseed_strength
115
+ self.seed_resize_from_h: int = seed_resize_from_h
116
+ self.seed_resize_from_w: int = seed_resize_from_w
117
+ self.sampler_name: str = sampler_name
118
+ self.batch_size: int = batch_size
119
+ self.n_iter: int = n_iter
120
+ self.steps: int = steps
121
+ self.cfg_scale: float = cfg_scale
122
+ self.width: int = width
123
+ self.height: int = height
124
+ self.restore_faces: bool = restore_faces
125
+ self.tiling: bool = tiling
126
+ self.do_not_save_samples: bool = do_not_save_samples
127
+ self.do_not_save_grid: bool = do_not_save_grid
128
+ self.extra_generation_params: dict = extra_generation_params or {}
129
+ self.overlay_images = overlay_images
130
+ self.eta = eta
131
+ self.do_not_reload_embeddings = do_not_reload_embeddings
132
+ self.paste_to = None
133
+ self.color_corrections = None
134
+ self.denoising_strength: float = denoising_strength
135
+ self.sampler_noise_scheduler_override = None
136
+ self.ddim_discretize = ddim_discretize or opts.ddim_discretize
137
+ self.s_churn = s_churn or opts.s_churn
138
+ self.s_tmin = s_tmin or opts.s_tmin
139
+ self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
140
+ self.s_noise = s_noise or opts.s_noise
141
+ self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
142
+ self.override_settings_restore_afterwards = override_settings_restore_afterwards
143
+ self.is_using_inpainting_conditioning = False
144
+ self.disable_extra_networks = False
145
+
146
+ if not seed_enable_extras:
147
+ self.subseed = -1
148
+ self.subseed_strength = 0
149
+ self.seed_resize_from_h = 0
150
+ self.seed_resize_from_w = 0
151
+
152
+ self.scripts = None
153
+ self.script_args = script_args
154
+ self.all_prompts = None
155
+ self.all_negative_prompts = None
156
+ self.all_seeds = None
157
+ self.all_subseeds = None
158
+ self.iteration = 0
159
+
160
+ @property
161
+ def sd_model(self):
162
+ return shared.sd_model
163
+
164
+ def txt2img_image_conditioning(self, x, width=None, height=None):
165
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
166
+
167
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
168
+
169
+ def depth2img_image_conditioning(self, source_image):
170
+ # Use the AddMiDaS helper to format our source image to suit the MiDaS model
171
+ transformer = AddMiDaS(model_type="dpt_hybrid")
172
+ transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
173
+ midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
174
+ midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
175
+
176
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
177
+ conditioning = torch.nn.functional.interpolate(
178
+ self.sd_model.depth_model(midas_in),
179
+ size=conditioning_image.shape[2:],
180
+ mode="bicubic",
181
+ align_corners=False,
182
+ )
183
+
184
+ (depth_min, depth_max) = torch.aminmax(conditioning)
185
+ conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
186
+ return conditioning
187
+
188
+ def edit_image_conditioning(self, source_image):
189
+ conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
190
+
191
+ return conditioning_image
192
+
193
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
194
+ self.is_using_inpainting_conditioning = True
195
+
196
+ # Handle the different mask inputs
197
+ if image_mask is not None:
198
+ if torch.is_tensor(image_mask):
199
+ conditioning_mask = image_mask
200
+ else:
201
+ conditioning_mask = np.array(image_mask.convert("L"))
202
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
203
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
204
+
205
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
206
+ conditioning_mask = torch.round(conditioning_mask)
207
+ else:
208
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
209
+
210
+ # Create another latent image, this time with a masked version of the original input.
211
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
212
+ conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
213
+ conditioning_image = torch.lerp(
214
+ source_image,
215
+ source_image * (1.0 - conditioning_mask),
216
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
217
+ )
218
+
219
+ # Encode the new masked image using first stage of network.
220
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
221
+
222
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
223
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
224
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
225
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
226
+ image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
227
+
228
+ return image_conditioning
229
+
230
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
231
+ source_image = devices.cond_cast_float(source_image)
232
+
233
+ # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
234
+ # identify itself with a field common to all models. The conditioning_key is also hybrid.
235
+ if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
236
+ return self.depth2img_image_conditioning(source_image)
237
+
238
+ if self.sd_model.cond_stage_key == "edit":
239
+ return self.edit_image_conditioning(source_image)
240
+
241
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
242
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
243
+
244
+ # Dummy zero conditioning if we're not using inpainting or depth model.
245
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
246
+
247
+ def init(self, all_prompts, all_seeds, all_subseeds):
248
+ pass
249
+
250
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
251
+ raise NotImplementedError()
252
+
253
+ def close(self):
254
+ self.sampler = None
255
+
256
+
257
+ class Processed:
258
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
259
+ self.images = images_list
260
+ self.prompt = p.prompt
261
+ self.negative_prompt = p.negative_prompt
262
+ self.seed = seed
263
+ self.subseed = subseed
264
+ self.subseed_strength = p.subseed_strength
265
+ self.info = info
266
+ self.comments = comments
267
+ self.width = p.width
268
+ self.height = p.height
269
+ self.sampler_name = p.sampler_name
270
+ self.cfg_scale = p.cfg_scale
271
+ self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
272
+ self.steps = p.steps
273
+ self.batch_size = p.batch_size
274
+ self.restore_faces = p.restore_faces
275
+ self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
276
+ self.sd_model_hash = shared.sd_model.sd_model_hash
277
+ self.seed_resize_from_w = p.seed_resize_from_w
278
+ self.seed_resize_from_h = p.seed_resize_from_h
279
+ self.denoising_strength = getattr(p, 'denoising_strength', None)
280
+ self.extra_generation_params = p.extra_generation_params
281
+ self.index_of_first_image = index_of_first_image
282
+ self.styles = p.styles
283
+ self.job_timestamp = state.job_timestamp
284
+ self.clip_skip = opts.CLIP_stop_at_last_layers
285
+
286
+ self.eta = p.eta
287
+ self.ddim_discretize = p.ddim_discretize
288
+ self.s_churn = p.s_churn
289
+ self.s_tmin = p.s_tmin
290
+ self.s_tmax = p.s_tmax
291
+ self.s_noise = p.s_noise
292
+ self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
293
+ self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
294
+ self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
295
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
296
+ self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
297
+ self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
298
+
299
+ self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
300
+ self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
301
+ self.all_seeds = all_seeds or p.all_seeds or [self.seed]
302
+ self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
303
+ self.infotexts = infotexts or [info]
304
+
305
+ def js(self):
306
+ obj = {
307
+ "prompt": self.all_prompts[0],
308
+ "all_prompts": self.all_prompts,
309
+ "negative_prompt": self.all_negative_prompts[0],
310
+ "all_negative_prompts": self.all_negative_prompts,
311
+ "seed": self.seed,
312
+ "all_seeds": self.all_seeds,
313
+ "subseed": self.subseed,
314
+ "all_subseeds": self.all_subseeds,
315
+ "subseed_strength": self.subseed_strength,
316
+ "width": self.width,
317
+ "height": self.height,
318
+ "sampler_name": self.sampler_name,
319
+ "cfg_scale": self.cfg_scale,
320
+ "steps": self.steps,
321
+ "batch_size": self.batch_size,
322
+ "restore_faces": self.restore_faces,
323
+ "face_restoration_model": self.face_restoration_model,
324
+ "sd_model_hash": self.sd_model_hash,
325
+ "seed_resize_from_w": self.seed_resize_from_w,
326
+ "seed_resize_from_h": self.seed_resize_from_h,
327
+ "denoising_strength": self.denoising_strength,
328
+ "extra_generation_params": self.extra_generation_params,
329
+ "index_of_first_image": self.index_of_first_image,
330
+ "infotexts": self.infotexts,
331
+ "styles": self.styles,
332
+ "job_timestamp": self.job_timestamp,
333
+ "clip_skip": self.clip_skip,
334
+ "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
335
+ }
336
+
337
+ return json.dumps(obj)
338
+
339
+ def infotext(self, p: StableDiffusionProcessing, index):
340
+ return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
341
+
342
+
343
+ # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
344
+ def slerp(val, low, high):
345
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
346
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
347
+ dot = (low_norm*high_norm).sum(1)
348
+
349
+ if dot.mean() > 0.9995:
350
+ return low * val + high * (1 - val)
351
+
352
+ omega = torch.acos(dot)
353
+ so = torch.sin(omega)
354
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
355
+ return res
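
The slerp above is standard spherical linear interpolation between the two noise tensors: with omega = arccos(low_hat · high_hat) computed from the normalized endpoints,

    slerp(val) = sin((1 - val) * omega) / sin(omega) * low + sin(val * omega) / sin(omega) * high

and it falls back to a plain linear blend when the endpoints are nearly parallel (dot product above 0.9995), where the spherical form becomes numerically unstable.
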
356
+
357
+
358
+ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
359
+ eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
360
+ xs = []
361
+
362
+ # if we have multiple seeds, this means we are working with batch size>1; this then
363
+ # enables the generation of additional tensors with noise that the sampler will use during its processing.
364
+ # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
365
+ # produce the same images as with two batches [100], [101].
366
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
367
+ sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
368
+ else:
369
+ sampler_noises = None
370
+
371
+ for i, seed in enumerate(seeds):
372
+ noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
373
+
374
+ subnoise = None
375
+ if subseeds is not None:
376
+ subseed = 0 if i >= len(subseeds) else subseeds[i]
377
+
378
+ subnoise = devices.randn(subseed, noise_shape)
379
+
380
+ # randn results depend on device; gpu and cpu get different results for same seed;
381
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
382
+ # but the original script had it like this, so I do not dare change it for now because
383
+ # it will break everyone's seeds.
384
+ noise = devices.randn(seed, noise_shape)
385
+
386
+ if subnoise is not None:
387
+ noise = slerp(subseed_strength, noise, subnoise)
388
+
389
+ if noise_shape != shape:
390
+ x = devices.randn(seed, shape)
391
+ dx = (shape[2] - noise_shape[2]) // 2
392
+ dy = (shape[1] - noise_shape[1]) // 2
393
+ w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
394
+ h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
395
+ tx = 0 if dx < 0 else dx
396
+ ty = 0 if dy < 0 else dy
397
+ dx = max(-dx, 0)
398
+ dy = max(-dy, 0)
399
+
400
+ x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
401
+ noise = x
402
+
403
+ if sampler_noises is not None:
404
+ cnt = p.sampler.number_of_needed_noises(p)
405
+
406
+ if eta_noise_seed_delta > 0:
407
+ torch.manual_seed(seed + eta_noise_seed_delta)
408
+
409
+ for j in range(cnt):
410
+ sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
411
+
412
+ xs.append(noise)
413
+
414
+ if sampler_noises is not None:
415
+ p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
416
+
417
+ x = torch.stack(xs).to(shared.device)
418
+ return x
419
+
420
+
421
+ def decode_first_stage(model, x):
422
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
423
+ x = model.decode_first_stage(x)
424
+
425
+ return x
426
+
427
+
428
+ def get_fixed_seed(seed):
429
+ if seed is None or seed == '' or seed == -1:
430
+ return int(random.randrange(4294967294))
431
+
432
+ return seed
433
+
434
+
435
+ def fix_seed(p):
436
+ p.seed = get_fixed_seed(p.seed)
437
+ p.subseed = get_fixed_seed(p.subseed)
438
+
439
+
440
+ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
441
+ index = position_in_batch + iteration * p.batch_size
442
+
443
+ clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
444
+
445
+ generation_params = {
446
+ "Steps": p.steps,
447
+ "Sampler": p.sampler_name,
448
+ "CFG scale": p.cfg_scale,
449
+ "Image CFG scale": getattr(p, 'image_cfg_scale', None),
450
+ "Seed": all_seeds[index],
451
+ "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
452
+ "Size": f"{p.width}x{p.height}",
453
+ "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
454
+ "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
455
+ "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
456
+ "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
457
+ "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
458
+ "Denoising strength": getattr(p, 'denoising_strength', None),
459
+ "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
460
+ "Clip skip": None if clip_skip <= 1 else clip_skip,
461
+ "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
462
+ }
463
+
464
+ generation_params.update(p.extra_generation_params)
465
+
466
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
467
+
468
+ negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
469
+
470
+ return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
471
+
472
+
473
+ def process_images(p: StableDiffusionProcessing) -> Processed:
474
+ stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
475
+
476
+ try:
477
+ for k, v in p.override_settings.items():
478
+ setattr(opts, k, v)
479
+
480
+ if k == 'sd_model_checkpoint':
481
+ sd_models.reload_model_weights()
482
+
483
+ if k == 'sd_vae':
484
+ sd_vae.reload_vae_weights()
485
+
486
+ res = process_images_inner(p)
487
+
488
+ finally:
489
+ # restore opts to original state
490
+ if p.override_settings_restore_afterwards:
491
+ for k, v in stored_opts.items():
492
+ setattr(opts, k, v)
493
+ if k == 'sd_model_checkpoint':
494
+ sd_models.reload_model_weights()
495
+
496
+ if k == 'sd_vae':
497
+ sd_vae.reload_vae_weights()
498
+
499
+ return res
500
+
501
+
502
+ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
503
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
504
+
505
+ if type(p.prompt) == list:
506
+ assert(len(p.prompt) > 0)
507
+ else:
508
+ assert p.prompt is not None
509
+
510
+ devices.torch_gc()
511
+
512
+ seed = get_fixed_seed(p.seed)
513
+ subseed = get_fixed_seed(p.subseed)
514
+
515
+ modules.sd_hijack.model_hijack.apply_circular(p.tiling)
516
+ modules.sd_hijack.model_hijack.clear_comments()
517
+
518
+ comments = {}
519
+
520
+ if type(p.prompt) == list:
521
+ p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
522
+ else:
523
+ p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
524
+
525
+ if type(p.negative_prompt) == list:
526
+ p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
527
+ else:
528
+ p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
529
+
530
+ if type(seed) == list:
531
+ p.all_seeds = seed
532
+ else:
533
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
534
+
535
+ if type(subseed) == list:
536
+ p.all_subseeds = subseed
537
+ else:
538
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
539
+
540
+ def infotext(iteration=0, position_in_batch=0):
541
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
542
+
543
+ if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
544
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
545
+
546
+ if p.scripts is not None:
547
+ p.scripts.process(p)
548
+
549
+ infotexts = []
550
+ output_images = []
551
+
552
+ cached_uc = [None, None]
553
+ cached_c = [None, None]
554
+
555
+ def get_conds_with_caching(function, required_prompts, steps, cache):
556
+ """
557
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
558
+ using a cache to store the result if the same arguments have been used before.
559
+
560
+ cache is an array containing two elements. The first element is a tuple
561
+ representing the previously used arguments, or None if no arguments
562
+ have been used before. The second element is where the previously
563
+ computed result is stored.
564
+ """
565
+
566
+ if cache[0] is not None and (required_prompts, steps) == cache[0]:
567
+ return cache[1]
568
+
569
+ with devices.autocast():
570
+ cache[1] = function(shared.sd_model, required_prompts, steps)
571
+
572
+ cache[0] = (required_prompts, steps)
573
+ return cache[1]
574
+
575
+ with torch.no_grad(), p.sd_model.ema_scope():
576
+ with devices.autocast():
577
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
578
+
579
+ # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
580
+ if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
581
+ sd_vae_approx.model()
582
+
583
+ if state.job_count == -1:
584
+ state.job_count = p.n_iter
585
+
586
+ for n in range(p.n_iter):
587
+ p.iteration = n
588
+
589
+ if state.skipped:
590
+ state.skipped = False
591
+
592
+ if state.interrupted:
593
+ break
594
+
595
+ prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
596
+ negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
597
+ seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
598
+ subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
599
+
600
+ if len(prompts) == 0:
601
+ break
602
+
603
+ prompts, extra_network_data = extra_networks.parse_prompts(prompts)
604
+
605
+ if not p.disable_extra_networks:
606
+ with devices.autocast():
607
+ extra_networks.activate(p, extra_network_data)
608
+
609
+ if p.scripts is not None:
610
+ p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
611
+
612
+ # params.txt should be saved after scripts.process_batch, since the
613
+ # infotext could be modified by that callback
614
+ # Example: a wildcard processed by process_batch sets an extra model
615
+ # strength, which is saved as "Model Strength: 1.0" in the infotext
616
+ if n == 0:
617
+ with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
618
+ processed = Processed(p, [], p.seed, "")
619
+ file.write(processed.infotext(p, 0))
620
+
621
+ uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
622
+ c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
623
+
624
+ if len(model_hijack.comments) > 0:
625
+ for comment in model_hijack.comments:
626
+ comments[comment] = 1
627
+
628
+ if p.n_iter > 1:
629
+ shared.state.job = f"Batch {n+1} out of {p.n_iter}"
630
+
631
+ with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
632
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
633
+
634
+ x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
635
+ for x in x_samples_ddim:
636
+ devices.test_for_nans(x, "vae")
637
+
638
+ x_samples_ddim = torch.stack(x_samples_ddim).float()
639
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
640
+
641
+ del samples_ddim
642
+
643
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
644
+ lowvram.send_everything_to_cpu()
645
+
646
+ devices.torch_gc()
647
+
648
+ if p.scripts is not None:
649
+ p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
650
+
651
+ for i, x_sample in enumerate(x_samples_ddim):
652
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
653
+ x_sample = x_sample.astype(np.uint8)
654
+
655
+ if p.restore_faces:
656
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
657
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
658
+
659
+ devices.torch_gc()
660
+
661
+ x_sample = modules.face_restoration.restore_faces(x_sample)
662
+ devices.torch_gc()
663
+
664
+ image = Image.fromarray(x_sample)
665
+
666
+ if p.scripts is not None:
667
+ pp = scripts.PostprocessImageArgs(image)
668
+ p.scripts.postprocess_image(p, pp)
669
+ image = pp.image
670
+
671
+ if p.color_corrections is not None and i < len(p.color_corrections):
672
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
673
+ image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
674
+ images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
675
+ image = apply_color_correction(p.color_corrections[i], image)
676
+
677
+ image = apply_overlay(image, p.paste_to, i, p.overlay_images)
678
+
679
+ if opts.samples_save and not p.do_not_save_samples:
680
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
681
+
682
+ text = infotext(n, i)
683
+ infotexts.append(text)
684
+ if opts.enable_pnginfo:
685
+ image.info["parameters"] = text
686
+ output_images.append(image)
687
+
688
+ del x_samples_ddim
689
+
690
+ devices.torch_gc()
691
+
692
+ state.nextjob()
693
+
694
+ p.color_corrections = None
695
+
696
+ index_of_first_image = 0
697
+ unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
698
+ if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
699
+ grid = images.image_grid(output_images, p.batch_size)
700
+
701
+ if opts.return_grid:
702
+ text = infotext()
703
+ infotexts.insert(0, text)
704
+ if opts.enable_pnginfo:
705
+ grid.info["parameters"] = text
706
+ output_images.insert(0, grid)
707
+ index_of_first_image = 1
708
+
709
+ if opts.grid_save:
710
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
711
+
712
+ if not p.disable_extra_networks:
713
+ extra_networks.deactivate(p, extra_network_data)
714
+
715
+ devices.torch_gc()
716
+
717
+ res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
718
+
719
+ if p.scripts is not None:
720
+ p.scripts.postprocess(p, res)
721
+
722
+ return res
723
+
724
+
725
+ def old_hires_fix_first_pass_dimensions(width, height):
726
+ """old algorithm for auto-calculating first pass size"""
727
+
728
+ desired_pixel_count = 512 * 512
729
+ actual_pixel_count = width * height
730
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
731
+ width = math.ceil(scale * width / 64) * 64
732
+ height = math.ceil(scale * height / 64) * 64
733
+
734
+ return width, height
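+ # Worked example of the formula above: a 1024x1024 request has four times the
+ # desired 512*512 pixel count, so scale = sqrt(0.25) = 0.5 and the first pass
+ # becomes 512x512; a 768x512 request gives scale ~= 0.82 and a 640x448 first
+ # pass (each dimension rounded up to a multiple of 64).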
735
+
736
+
737
+ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
738
+ sampler = None
739
+
740
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
741
+ super().__init__(**kwargs)
742
+ self.enable_hr = enable_hr
743
+ self.denoising_strength = denoising_strength
744
+ self.hr_scale = hr_scale
745
+ self.hr_upscaler = hr_upscaler
746
+ self.hr_second_pass_steps = hr_second_pass_steps
747
+ self.hr_resize_x = hr_resize_x
748
+ self.hr_resize_y = hr_resize_y
749
+ self.hr_upscale_to_x = hr_resize_x
750
+ self.hr_upscale_to_y = hr_resize_y
751
+
752
+ if firstphase_width != 0 or firstphase_height != 0:
753
+ self.hr_upscale_to_x = self.width
754
+ self.hr_upscale_to_y = self.height
755
+ self.width = firstphase_width
756
+ self.height = firstphase_height
757
+
758
+ self.truncate_x = 0
759
+ self.truncate_y = 0
760
+ self.applied_old_hires_behavior_to = None
761
+
762
+ def init(self, all_prompts, all_seeds, all_subseeds):
763
+ if self.enable_hr:
764
+ if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
765
+ self.hr_resize_x = self.width
766
+ self.hr_resize_y = self.height
767
+ self.hr_upscale_to_x = self.width
768
+ self.hr_upscale_to_y = self.height
769
+
770
+ self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
771
+ self.applied_old_hires_behavior_to = (self.width, self.height)
772
+
773
+ if self.hr_resize_x == 0 and self.hr_resize_y == 0:
774
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
775
+ self.hr_upscale_to_x = int(self.width * self.hr_scale)
776
+ self.hr_upscale_to_y = int(self.height * self.hr_scale)
777
+ else:
778
+ self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
779
+
780
+ if self.hr_resize_y == 0:
781
+ self.hr_upscale_to_x = self.hr_resize_x
782
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
783
+ elif self.hr_resize_x == 0:
784
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
785
+ self.hr_upscale_to_y = self.hr_resize_y
786
+ else:
787
+ target_w = self.hr_resize_x
788
+ target_h = self.hr_resize_y
789
+ src_ratio = self.width / self.height
790
+ dst_ratio = self.hr_resize_x / self.hr_resize_y
791
+
792
+ if src_ratio < dst_ratio:
793
+ self.hr_upscale_to_x = self.hr_resize_x
794
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
795
+ else:
796
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
797
+ self.hr_upscale_to_y = self.hr_resize_y
798
+
799
+ self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
800
+ self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
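+ # Example of this branch, assuming opt_f == 8 (the latent downscale factor
+ # defined earlier in this module): a 512x512 first pass with hires resize
+ # 1024x768 upscales to 1024x1024 to keep the aspect ratio, then records
+ # truncate_y = (1024 - 768) // 8 = 32 latent rows to crop before the second pass.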
801
+
802
+ # special case: the user has chosen to do nothing
803
+ if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
804
+ self.enable_hr = False
805
+ self.denoising_strength = None
806
+ self.extra_generation_params.pop("Hires upscale", None)
807
+ self.extra_generation_params.pop("Hires resize", None)
808
+ return
809
+
810
+ if not state.processing_has_refined_job_count:
811
+ if state.job_count == -1:
812
+ state.job_count = self.n_iter
813
+
814
+ shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
815
+ state.job_count = state.job_count * 2
816
+ state.processing_has_refined_job_count = True
817
+
818
+ if self.hr_second_pass_steps:
819
+ self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
820
+
821
+ if self.hr_upscaler is not None:
822
+ self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
823
+
824
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
825
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
826
+
827
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
828
+ if self.enable_hr and latent_scale_mode is None:
829
+ assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
830
+
831
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
832
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
833
+
834
+ if not self.enable_hr:
835
+ return samples
836
+
837
+ target_width = self.hr_upscale_to_x
838
+ target_height = self.hr_upscale_to_y
839
+
840
+ def save_intermediate(image, index):
841
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
842
+
843
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
844
+ return
845
+
846
+ if not isinstance(image, Image.Image):
847
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
848
+
849
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
850
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
851
+
852
+ if latent_scale_mode is not None:
853
+ for i in range(samples.shape[0]):
854
+ save_intermediate(samples, i)
855
+
856
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
857
+
858
+ # Avoid making the inpainting conditioning unless necessary as
859
+ # this does need some extra compute to decode / encode the image again.
860
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
861
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
862
+ else:
863
+ image_conditioning = self.txt2img_image_conditioning(samples)
864
+ else:
865
+ decoded_samples = decode_first_stage(self.sd_model, samples)
866
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
867
+
868
+ batch_images = []
869
+ for i, x_sample in enumerate(lowres_samples):
870
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
871
+ x_sample = x_sample.astype(np.uint8)
872
+ image = Image.fromarray(x_sample)
873
+
874
+ save_intermediate(image, i)
875
+
876
+ image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
877
+ image = np.array(image).astype(np.float32) / 255.0
878
+ image = np.moveaxis(image, 2, 0)
879
+ batch_images.append(image)
880
+
881
+ decoded_samples = torch.from_numpy(np.array(batch_images))
882
+ decoded_samples = decoded_samples.to(shared.device)
883
+ decoded_samples = 2. * decoded_samples - 1.
884
+
885
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
886
+
887
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
888
+
889
+ shared.state.nextjob()
890
+
891
+ img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch to DDIM
892
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
893
+
894
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
895
+
896
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
897
+
898
+ # GC now before running the next img2img to prevent running out of memory
899
+ x = None
900
+ devices.torch_gc()
901
+
902
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
903
+
904
+ return samples
905
+
906
+
907
+ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
908
+ sampler = None
909
+
910
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
911
+ super().__init__(**kwargs)
912
+
913
+ self.init_images = init_images
914
+ self.resize_mode: int = resize_mode
915
+ self.denoising_strength: float = denoising_strength
916
+ self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
917
+ self.init_latent = None
918
+ self.image_mask = mask
919
+ self.latent_mask = None
920
+ self.mask_for_overlay = None
921
+ self.mask_blur = mask_blur
922
+ self.inpainting_fill = inpainting_fill
923
+ self.inpaint_full_res = inpaint_full_res
924
+ self.inpaint_full_res_padding = inpaint_full_res_padding
925
+ self.inpainting_mask_invert = inpainting_mask_invert
926
+ self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
927
+ self.mask = None
928
+ self.nmask = None
929
+ self.image_conditioning = None
930
+
931
+ def init(self, all_prompts, all_seeds, all_subseeds):
932
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
933
+ crop_region = None
934
+
935
+ image_mask = self.image_mask
936
+
937
+ if image_mask is not None:
938
+ image_mask = image_mask.convert('L')
939
+
940
+ if self.inpainting_mask_invert:
941
+ image_mask = ImageOps.invert(image_mask)
942
+
943
+ if self.mask_blur > 0:
944
+ image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
945
+
946
+ if self.inpaint_full_res:
947
+ self.mask_for_overlay = image_mask
948
+ mask = image_mask.convert('L')
949
+ crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
950
+ crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
951
+ x1, y1, x2, y2 = crop_region
952
+
953
+ mask = mask.crop(crop_region)
954
+ image_mask = images.resize_image(2, mask, self.width, self.height)
955
+ self.paste_to = (x1, y1, x2-x1, y2-y1)
956
+ else:
957
+ image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
958
+ np_mask = np.array(image_mask)
959
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
960
+ self.mask_for_overlay = Image.fromarray(np_mask)
961
+
962
+ self.overlay_images = []
963
+
964
+ latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
965
+
966
+ add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
967
+ if add_color_corrections:
968
+ self.color_corrections = []
969
+ imgs = []
970
+ for img in self.init_images:
971
+ image = images.flatten(img, opts.img2img_background_color)
972
+
973
+ if crop_region is None and self.resize_mode != 3:
974
+ image = images.resize_image(self.resize_mode, image, self.width, self.height)
975
+
976
+ if image_mask is not None:
977
+ image_masked = Image.new('RGBa', (image.width, image.height))
978
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
979
+
980
+ self.overlay_images.append(image_masked.convert('RGBA'))
981
+
982
+ # crop_region is not None if we are doing inpaint full res
983
+ if crop_region is not None:
984
+ image = image.crop(crop_region)
985
+ image = images.resize_image(2, image, self.width, self.height)
986
+
987
+ if image_mask is not None:
988
+ if self.inpainting_fill != 1:
989
+ image = masking.fill(image, latent_mask)
990
+
991
+ if add_color_corrections:
992
+ self.color_corrections.append(setup_color_correction(image))
993
+
994
+ image = np.array(image).astype(np.float32) / 255.0
995
+ image = np.moveaxis(image, 2, 0)
996
+
997
+ imgs.append(image)
998
+
999
+ if len(imgs) == 1:
1000
+ batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
1001
+ if self.overlay_images is not None:
1002
+ self.overlay_images = self.overlay_images * self.batch_size
1003
+
1004
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
1005
+ self.color_corrections = self.color_corrections * self.batch_size
1006
+
1007
+ elif len(imgs) <= self.batch_size:
1008
+ self.batch_size = len(imgs)
1009
+ batch_images = np.array(imgs)
1010
+ else:
1011
+ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
1012
+
1013
+ image = torch.from_numpy(batch_images)
1014
+ image = 2. * image - 1.
1015
+ image = image.to(shared.device)
1016
+
1017
+ self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
1018
+
1019
+ if self.resize_mode == 3:
1020
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
1021
+
1022
+ if image_mask is not None:
1023
+ init_mask = latent_mask
1024
+ latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
1025
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
1026
+ latmask = latmask[0]
1027
+ latmask = np.around(latmask)
1028
+ latmask = np.tile(latmask[None], (4, 1, 1))
1029
+
1030
+ self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
1031
+ self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
1032
+
1033
+ # this needs to be fixed to be done in sample() using actual seeds for batches
1034
+ if self.inpainting_fill == 2:
1035
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
1036
+ elif self.inpainting_fill == 3:
1037
+ self.init_latent = self.init_latent * self.mask
1038
+
1039
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
1040
+
1041
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
1042
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
1043
+
1044
+ if self.initial_noise_multiplier != 1.0:
1045
+ self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
1046
+ x *= self.initial_noise_multiplier
1047
+
1048
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
1049
+
1050
+ if self.mask is not None:
1051
+ samples = samples * self.nmask + self.init_latent * self.mask
1052
+
1053
+ del x
1054
+ devices.torch_gc()
1055
+
1056
+ return samples
sd/stable-diffusion-webui/modules/progress.py ADDED
@@ -0,0 +1,99 @@
1
+ import base64
2
+ import io
3
+ import time
4
+
5
+ import gradio as gr
6
+ from pydantic import BaseModel, Field
7
+
8
+ from modules.shared import opts
9
+
10
+ import modules.shared as shared
11
+
12
+
13
+ current_task = None
14
+ pending_tasks = {}
15
+ finished_tasks = []
16
+
17
+
18
+ def start_task(id_task):
19
+ global current_task
20
+
21
+ current_task = id_task
22
+ pending_tasks.pop(id_task, None)
23
+
24
+
25
+ def finish_task(id_task):
26
+ global current_task
27
+
28
+ if current_task == id_task:
29
+ current_task = None
30
+
31
+ finished_tasks.append(id_task)
32
+ if len(finished_tasks) > 16:
33
+ finished_tasks.pop(0)
34
+
35
+
36
+ def add_task_to_queue(id_job):
37
+ pending_tasks[id_job] = time.time()
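+ # Together with start_task and finish_task above, this gives each job a simple
+ # lifecycle: queued (listed in pending_tasks), active (current_task), then
+ # completed (listed in finished_tasks); progressapi below reports which of these
+ # states a given task id is in.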
38
+
39
+
40
+ class ProgressRequest(BaseModel):
41
+ id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
42
+ id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
43
+
44
+
45
+ class ProgressResponse(BaseModel):
46
+ active: bool = Field(title="Whether the task is being worked on right now")
47
+ queued: bool = Field(title="Whether the task is in queue")
48
+ completed: bool = Field(title="Whether the task has already finished")
49
+ progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
50
+ eta: float = Field(default=None, title="ETA in secs")
51
+ live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
52
+ id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
53
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
54
+
55
+
56
+ def setup_progress_api(app):
57
+ return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
58
+
59
+
60
+ def progressapi(req: ProgressRequest):
61
+ active = req.id_task == current_task
62
+ queued = req.id_task in pending_tasks
63
+ completed = req.id_task in finished_tasks
64
+
65
+ if not active:
66
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
67
+
68
+ progress = 0
69
+
70
+ job_count, job_no = shared.state.job_count, shared.state.job_no
71
+ sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
72
+
73
+ if job_count > 0:
74
+ progress += job_no / job_count
75
+ if sampling_steps > 0 and job_count > 0:
76
+ progress += 1 / job_count * sampling_step / sampling_steps
77
+
78
+ progress = min(progress, 1)
79
+
80
+ elapsed_since_start = time.time() - shared.state.time_start
81
+ predicted_duration = elapsed_since_start / progress if progress > 0 else None
82
+ eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
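+ # Example of the estimate above: with job_count=2, job_no=1 and 10 of 20
+ # sampling steps done, progress = 1/2 + (1/2) * (10/20) = 0.75; after 30s of
+ # elapsed time the predicted duration is 40s, so eta = 10s.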
83
+
84
+ id_live_preview = req.id_live_preview
85
+ shared.state.set_current_image()
86
+ if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
87
+ image = shared.state.current_image
88
+ if image is not None:
89
+ buffered = io.BytesIO()
90
+ image.save(buffered, format="png")
91
+ live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
92
+ id_live_preview = shared.state.id_live_preview
93
+ else:
94
+ live_preview = None
95
+ else:
96
+ live_preview = None
97
+
98
+ return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
99
+
sd/stable-diffusion-webui/modules/prompt_parser.py ADDED
@@ -0,0 +1,373 @@
1
+ import re
2
+ from collections import namedtuple
3
+ from typing import List
4
+ import lark
5
+
6
+ # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
7
+ # will be represented with prompt_schedule like this (assuming steps=100):
8
+ # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
9
+ # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
10
+ # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
11
+ # [75, 'fantasy landscape with a lake and an oak in background masterful']
12
+ # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
13
+
14
+ schedule_parser = lark.Lark(r"""
15
+ !start: (prompt | /[][():]/+)*
16
+ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
17
+ !emphasized: "(" prompt ")"
18
+ | "(" prompt ":" prompt ")"
19
+ | "[" prompt "]"
20
+ scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
21
+ alternate: "[" prompt ("|" prompt)+ "]"
22
+ WHITESPACE: /\s+/
23
+ plain: /([^\\\[\]():|]|\\.)+/
24
+ %import common.SIGNED_NUMBER -> NUMBER
25
+ """)
26
+
27
+ def get_learned_conditioning_prompt_schedules(prompts, steps):
28
+ """
29
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
30
+ >>> g("test")
31
+ [[10, 'test']]
32
+ >>> g("a [b:3]")
33
+ [[3, 'a '], [10, 'a b']]
34
+ >>> g("a [b: 3]")
35
+ [[3, 'a '], [10, 'a b']]
36
+ >>> g("a [[[b]]:2]")
37
+ [[2, 'a '], [10, 'a [[b]]']]
38
+ >>> g("[(a:2):3]")
39
+ [[3, ''], [10, '(a:2)']]
40
+ >>> g("a [b : c : 1] d")
41
+ [[1, 'a b d'], [10, 'a c d']]
42
+ >>> g("a[b:[c:d:2]:1]e")
43
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
44
+ >>> g("a [unbalanced")
45
+ [[10, 'a [unbalanced']]
46
+ >>> g("a [b:.5] c")
47
+ [[5, 'a c'], [10, 'a b c']]
48
+ >>> g("a [{b|d{:.5] c") # not handling this right now
49
+ [[5, 'a c'], [10, 'a {b|d{ c']]
50
+ >>> g("((a][:b:c [d:3]")
51
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
52
+ >>> g("[a|(b:1.1)]")
53
+ [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
54
+ """
55
+
56
+ def collect_steps(steps, tree):
57
+ l = [steps]
58
+ class CollectSteps(lark.Visitor):
59
+ def scheduled(self, tree):
60
+ tree.children[-1] = float(tree.children[-1])
61
+ if tree.children[-1] < 1:
62
+ tree.children[-1] *= steps
63
+ tree.children[-1] = min(steps, int(tree.children[-1]))
64
+ l.append(tree.children[-1])
65
+ def alternate(self, tree):
66
+ l.extend(range(1, steps+1))
67
+ CollectSteps().visit(tree)
68
+ return sorted(set(l))
69
+
70
+ def at_step(step, tree):
71
+ class AtStep(lark.Transformer):
72
+ def scheduled(self, args):
73
+ before, after, _, when = args
74
+ yield before or () if step <= when else after
75
+ def alternate(self, args):
76
+ yield next(args[(step - 1)%len(args)])
77
+ def start(self, args):
78
+ def flatten(x):
79
+ if type(x) == str:
80
+ yield x
81
+ else:
82
+ for gen in x:
83
+ yield from flatten(gen)
84
+ return ''.join(flatten(args))
85
+ def plain(self, args):
86
+ yield args[0].value
87
+ def __default__(self, data, children, meta):
88
+ for child in children:
89
+ yield child
90
+ return AtStep().transform(tree)
91
+
92
+ def get_schedule(prompt):
93
+ try:
94
+ tree = schedule_parser.parse(prompt)
95
+ except lark.exceptions.LarkError as e:
96
+ if 0:
97
+ import traceback
98
+ traceback.print_exc()
99
+ return [[steps, prompt]]
100
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
101
+
102
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
103
+ return [promptdict[prompt] for prompt in prompts]
104
+
105
+
106
+ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
107
+
108
+
109
+ def get_learned_conditioning(model, prompts, steps):
110
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
111
+ and the sampling step at which this condition is to be replaced by the next one.
112
+
113
+ Input:
114
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
115
+
116
+ Output:
117
+ [
118
+ [
119
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
120
+ ],
121
+ [
122
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
123
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
124
+ ]
125
+ ]
126
+ """
127
+ res = []
128
+
129
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
130
+ cache = {}
131
+
132
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
133
+
134
+ cached = cache.get(prompt, None)
135
+ if cached is not None:
136
+ res.append(cached)
137
+ continue
138
+
139
+ texts = [x[1] for x in prompt_schedule]
140
+ conds = model.get_learned_conditioning(texts)
141
+
142
+ cond_schedule = []
143
+ for i, (end_at_step, text) in enumerate(prompt_schedule):
144
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
145
+
146
+ cache[prompt] = cond_schedule
147
+ res.append(cond_schedule)
148
+
149
+ return res
150
+
151
+
152
+ re_AND = re.compile(r"\bAND\b")
153
+ re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
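+ # Example of the two patterns working together: "a forest AND a river:0.3" is
+ # split on AND into two subprompts; re_weight then gives the first the default
+ # weight 1.0 (no explicit weight) and parses an explicit weight of 0.3 for the
+ # second.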
154
+
155
+ def get_multicond_prompt_list(prompts):
156
+ res_indexes = []
157
+
158
+ prompt_flat_list = []
159
+ prompt_indexes = {}
160
+
161
+ for prompt in prompts:
162
+ subprompts = re_AND.split(prompt)
163
+
164
+ indexes = []
165
+ for subprompt in subprompts:
166
+ match = re_weight.search(subprompt)
167
+
168
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
169
+
170
+ weight = float(weight) if weight is not None else 1.0
171
+
172
+ index = prompt_indexes.get(text, None)
173
+ if index is None:
174
+ index = len(prompt_flat_list)
175
+ prompt_flat_list.append(text)
176
+ prompt_indexes[text] = index
177
+
178
+ indexes.append((index, weight))
179
+
180
+ res_indexes.append(indexes)
181
+
182
+ return res_indexes, prompt_flat_list, prompt_indexes
183
+
184
+
185
+ class ComposableScheduledPromptConditioning:
186
+ def __init__(self, schedules, weight=1.0):
187
+ self.schedules: List[ScheduledPromptConditioning] = schedules
188
+ self.weight: float = weight
189
+
190
+
191
+ class MulticondLearnedConditioning:
192
+ def __init__(self, shape, batch):
193
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
194
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
195
+
196
+ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
197
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
198
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
199
+
200
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
201
+ """
202
+
203
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
204
+
205
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
206
+
207
+ res = []
208
+ for indexes in res_indexes:
209
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
210
+
211
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
212
+
213
+
214
+ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
215
+ param = c[0][0].cond
216
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
217
+ for i, cond_schedule in enumerate(c):
218
+ target_index = 0
219
+ for current, (end_at, cond) in enumerate(cond_schedule):
220
+ if current_step <= end_at:
221
+ target_index = current
222
+ break
223
+ res[i] = cond_schedule[target_index].cond
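+ # e.g. a schedule [(end_at_step=5, cond_A), (end_at_step=20, cond_B)] resolves
+ # to cond_A while current_step <= 5 and to cond_B for the remaining steps,
+ # because the first entry whose end_at_step has not yet been passed is selected.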
224
+
225
+ return res
226
+
227
+
228
+ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
229
+ param = c.batch[0][0].schedules[0].cond
230
+
231
+ tensors = []
232
+ conds_list = []
233
+
234
+ for batch_no, composable_prompts in enumerate(c.batch):
235
+ conds_for_batch = []
236
+
237
+ for cond_index, composable_prompt in enumerate(composable_prompts):
238
+ target_index = 0
239
+ for current, (end_at, cond) in enumerate(composable_prompt.schedules):
240
+ if current_step <= end_at:
241
+ target_index = current
242
+ break
243
+
244
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
245
+ tensors.append(composable_prompt.schedules[target_index].cond)
246
+
247
+ conds_list.append(conds_for_batch)
248
+
249
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
250
+ # and won't be able to torch.stack them. So this fixes that.
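+ # e.g. a [77, dim] tensor next to a [154, dim] one (prompts chunked to different
+ # lengths): the shorter tensor is padded below by repeating its final vector
+ # until both reach token_count rows, so torch.stack succeeds.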
251
+ token_count = max([x.shape[0] for x in tensors])
252
+ for i in range(len(tensors)):
253
+ if tensors[i].shape[0] != token_count:
254
+ last_vector = tensors[i][-1:]
255
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
256
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
257
+
258
+ return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
259
+
260
+
261
+ re_attention = re.compile(r"""
262
+ \\\(|
263
+ \\\)|
264
+ \\\[|
265
+ \\]|
266
+ \\\\|
267
+ \\|
268
+ \(|
269
+ \[|
270
+ :([+-]?[.\d]+)\)|
271
+ \)|
272
+ ]|
273
+ [^\\()\[\]:]+|
274
+ :
275
+ """, re.X)
276
+
277
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
278
+
279
+ def parse_prompt_attention(text):
280
+ """
281
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
282
+ Accepted tokens are:
283
+ (abc) - increases attention to abc by a multiplier of 1.1
284
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
285
+ [abc] - decreases attention to abc by a multiplier of 1.1
286
+ \( - literal character '('
287
+ \[ - literal character '['
288
+ \) - literal character ')'
289
+ \] - literal character ']'
290
+ \\ - literal character '\'
291
+ anything else - just text
292
+
293
+ >>> parse_prompt_attention('normal text')
294
+ [['normal text', 1.0]]
295
+ >>> parse_prompt_attention('an (important) word')
296
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
297
+ >>> parse_prompt_attention('(unbalanced')
298
+ [['unbalanced', 1.1]]
299
+ >>> parse_prompt_attention('\(literal\]')
300
+ [['(literal]', 1.0]]
301
+ >>> parse_prompt_attention('(unnecessary)(parens)')
302
+ [['unnecessaryparens', 1.1]]
303
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
304
+ [['a ', 1.0],
305
+ ['house', 1.5730000000000004],
306
+ [' ', 1.1],
307
+ ['on', 1.0],
308
+ [' a ', 1.1],
309
+ ['hill', 0.55],
310
+ [', sun, ', 1.1],
311
+ ['sky', 1.4641000000000006],
312
+ ['.', 1.1]]
313
+ """
314
+
315
+ res = []
316
+ round_brackets = []
317
+ square_brackets = []
318
+
319
+ round_bracket_multiplier = 1.1
320
+ square_bracket_multiplier = 1 / 1.1
321
+
322
+ def multiply_range(start_position, multiplier):
323
+ for p in range(start_position, len(res)):
324
+ res[p][1] *= multiplier
325
+
326
+ for m in re_attention.finditer(text):
327
+ text = m.group(0)
328
+ weight = m.group(1)
329
+
330
+ if text.startswith('\\'):
331
+ res.append([text[1:], 1.0])
332
+ elif text == '(':
333
+ round_brackets.append(len(res))
334
+ elif text == '[':
335
+ square_brackets.append(len(res))
336
+ elif weight is not None and len(round_brackets) > 0:
337
+ multiply_range(round_brackets.pop(), float(weight))
338
+ elif text == ')' and len(round_brackets) > 0:
339
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
340
+ elif text == ']' and len(square_brackets) > 0:
341
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
342
+ else:
343
+ parts = re.split(re_break, text)
344
+ for i, part in enumerate(parts):
345
+ if i > 0:
346
+ res.append(["BREAK", -1])
347
+ res.append([part, 1.0])
348
+
349
+ for pos in round_brackets:
350
+ multiply_range(pos, round_bracket_multiplier)
351
+
352
+ for pos in square_brackets:
353
+ multiply_range(pos, square_bracket_multiplier)
354
+
355
+ if len(res) == 0:
356
+ res = [["", 1.0]]
357
+
358
+ # merge runs of identical weights
359
+ i = 0
360
+ while i + 1 < len(res):
361
+ if res[i][1] == res[i + 1][1]:
362
+ res[i][0] += res[i + 1][0]
363
+ res.pop(i + 1)
364
+ else:
365
+ i += 1
366
+
367
+ return res
368
+
369
+ if __name__ == "__main__":
370
+ import doctest
371
+ doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
372
+ else:
373
+ import torch # doctest faster
sd/stable-diffusion-webui/modules/realesrgan_model.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from realesrgan import RealESRGANer
9
+
10
+ from modules.upscaler import Upscaler, UpscalerData
11
+ from modules.shared import cmd_opts, opts
12
+
13
+
14
+ class UpscalerRealESRGAN(Upscaler):
15
+ def __init__(self, path):
16
+ self.name = "RealESRGAN"
17
+ self.user_path = path
18
+ super().__init__()
19
+ try:
20
+ from basicsr.archs.rrdbnet_arch import RRDBNet
21
+ from realesrgan import RealESRGANer
22
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
23
+ self.enable = True
24
+ self.scalers = []
25
+ scalers = self.load_models(path)
26
+ for scaler in scalers:
27
+ if scaler.name in opts.realesrgan_enabled_models:
28
+ self.scalers.append(scaler)
29
+
30
+ except Exception:
31
+ print("Error importing Real-ESRGAN:", file=sys.stderr)
32
+ print(traceback.format_exc(), file=sys.stderr)
33
+ self.enable = False
34
+ self.scalers = []
35
+
36
+ def do_upscale(self, img, path):
37
+ if not self.enable:
38
+ return img
39
+
40
+ info = self.load_model(path)
41
+ if not os.path.exists(info.local_data_path):
42
+ print("Unable to load RealESRGAN model: %s" % info.name)
43
+ return img
44
+
45
+ upsampler = RealESRGANer(
46
+ scale=info.scale,
47
+ model_path=info.local_data_path,
48
+ model=info.model(),
49
+ half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
50
+ tile=opts.ESRGAN_tile,
51
+ tile_pad=opts.ESRGAN_tile_overlap,
52
+ )
53
+
54
+ upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
55
+
56
+ image = Image.fromarray(upsampled)
57
+ return image
58
+
59
+ def load_model(self, path):
60
+ try:
61
+ info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
62
+
63
+ if info is None:
64
+ print(f"Unable to find model info: {path}")
65
+ return None
66
+
67
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
68
+ return info
69
+ except Exception as e:
70
+ print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
71
+ print(traceback.format_exc(), file=sys.stderr)
72
+ return None
73
+
74
+ def load_models(self, _):
75
+ return get_realesrgan_models(self)
76
+
77
+
78
+ def get_realesrgan_models(scaler):
79
+ try:
80
+ from basicsr.archs.rrdbnet_arch import RRDBNet
81
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
82
+ models = [
83
+ UpscalerData(
84
+ name="R-ESRGAN General 4xV3",
85
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
86
+ scale=4,
87
+ upscaler=scaler,
88
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
89
+ ),
90
+ UpscalerData(
91
+ name="R-ESRGAN General WDN 4xV3",
92
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
93
+ scale=4,
94
+ upscaler=scaler,
95
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
96
+ ),
97
+ UpscalerData(
98
+ name="R-ESRGAN AnimeVideo",
99
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
100
+ scale=4,
101
+ upscaler=scaler,
102
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
103
+ ),
104
+ UpscalerData(
105
+ name="R-ESRGAN 4x+",
106
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
107
+ scale=4,
108
+ upscaler=scaler,
109
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
110
+ ),
111
+ UpscalerData(
112
+ name="R-ESRGAN 4x+ Anime6B",
113
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
114
+ scale=4,
115
+ upscaler=scaler,
116
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
117
+ ),
118
+ UpscalerData(
119
+ name="R-ESRGAN 2x+",
120
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
121
+ scale=2,
122
+ upscaler=scaler,
123
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
124
+ ),
125
+ ]
126
+ return models
127
+ except Exception as e:
128
+ print("Error making Real-ESRGAN models list:", file=sys.stderr)
129
+ print(traceback.format_exc(), file=sys.stderr)
sd/stable-diffusion-webui/modules/safe.py ADDED
@@ -0,0 +1,192 @@
1
+ # this code is adapted from the script contributed by anon from /h/
2
+
3
+ import io
4
+ import pickle
5
+ import collections
6
+ import sys
7
+ import traceback
8
+
9
+ import torch
10
+ import numpy
11
+ import _codecs
12
+ import zipfile
13
+ import re
14
+
15
+
16
+ # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
17
+ TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
18
+
19
+
20
+ def encode(*args):
21
+ out = _codecs.encode(*args)
22
+ return out
23
+
24
+
25
+ class RestrictedUnpickler(pickle.Unpickler):
26
+ extra_handler = None
27
+
28
+ def persistent_load(self, saved_id):
29
+ assert saved_id[0] == 'storage'
30
+ return TypedStorage()
31
+
32
+ def find_class(self, module, name):
33
+ if self.extra_handler is not None:
34
+ res = self.extra_handler(module, name)
35
+ if res is not None:
36
+ return res
37
+
38
+ if module == 'collections' and name == 'OrderedDict':
39
+ return getattr(collections, name)
40
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
41
+ return getattr(torch._utils, name)
42
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
43
+ return getattr(torch, name)
44
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
45
+ return getattr(torch.nn.modules.container, name)
46
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
47
+ return getattr(numpy.core.multiarray, name)
48
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
49
+ return getattr(numpy, name)
50
+ if module == '_codecs' and name == 'encode':
51
+ return encode
52
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
53
+ import pytorch_lightning.callbacks
54
+ return pytorch_lightning.callbacks.model_checkpoint
55
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
56
+ import pytorch_lightning.callbacks.model_checkpoint
57
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
58
+ if module == "__builtin__" and name == 'set':
59
+ return set
60
+
61
+ # Forbid everything else.
62
+ raise Exception(f"global '{module}/{name}' is forbidden")
63
+
64
+
65
+ # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
66
+ allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
67
+ data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
68
+
69
+ def check_zip_filenames(filename, names):
70
+ for name in names:
71
+ if allowed_zip_names_re.match(name):
72
+ continue
73
+
74
+ raise Exception(f"bad file inside {filename}: {name}")
75
+
76
+
77
+ def check_pt(filename, extra_handler):
78
+ try:
79
+
80
+ # new pytorch format is a zip file
81
+ with zipfile.ZipFile(filename) as z:
82
+ check_zip_filenames(filename, z.namelist())
83
+
84
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
85
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
86
+ if len(data_pkl_filenames) == 0:
87
+ raise Exception(f"data.pkl not found in {filename}")
88
+ if len(data_pkl_filenames) > 1:
89
+ raise Exception(f"Multiple data.pkl found in {filename}")
90
+ with z.open(data_pkl_filenames[0]) as file:
91
+ unpickler = RestrictedUnpickler(file)
92
+ unpickler.extra_handler = extra_handler
93
+ unpickler.load()
94
+
95
+ except zipfile.BadZipfile:
96
+
97
+ # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
98
+ with open(filename, "rb") as file:
99
+ unpickler = RestrictedUnpickler(file)
100
+ unpickler.extra_handler = extra_handler
101
+ for i in range(5):
102
+ unpickler.load()
103
+
104
+
105
+ def load(filename, *args, **kwargs):
106
+ return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
107
+
108
+
109
+ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
110
+ """
111
+ this function is intended to be used by extensions that want to load models with
112
+ some extra classes in them that the usual unpickler would find suspicious.
113
+
114
+ Use the extra_handler argument to specify a function that takes module and field name as text,
115
+ and returns that field's value:
116
+
117
+ ```python
118
+ def extra(module, name):
119
+ if module == 'collections' and name == 'OrderedDict':
120
+ return collections.OrderedDict
121
+
122
+ return None
123
+
124
+ safe.load_with_extra('model.pt', extra_handler=extra)
125
+ ```
126
+
127
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
128
+ definitely unsafe.
129
+ """
130
+
131
+ from modules import shared
132
+
133
+ try:
134
+ if not shared.cmd_opts.disable_safe_unpickle:
135
+ check_pt(filename, extra_handler)
136
+
137
+ except pickle.UnpicklingError:
138
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
139
+ print(traceback.format_exc(), file=sys.stderr)
140
+ print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
141
+ print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
142
+ return None
143
+
144
+ except Exception:
145
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
146
+ print(traceback.format_exc(), file=sys.stderr)
147
+ print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
148
+ print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
149
+ return None
150
+
151
+ return unsafe_torch_load(filename, *args, **kwargs)
152
+
153
+
154
+ class Extra:
155
+ """
156
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
157
+ (because it's not your code making the torch.load call). The intended use is like this:
158
+
159
+ ```
160
+ import torch
161
+ from modules import safe
162
+
163
+ def handler(module, name):
164
+ if module == 'torch' and name in ['float64', 'float16']:
165
+ return getattr(torch, name)
166
+
167
+ return None
168
+
169
+ with safe.Extra(handler):
170
+ x = torch.load('model.pt')
171
+ ```
172
+ """
173
+
174
+ def __init__(self, handler):
175
+ self.handler = handler
176
+
177
+ def __enter__(self):
178
+ global global_extra_handler
179
+
180
+ assert global_extra_handler is None, 'already inside an Extra() block'
181
+ global_extra_handler = self.handler
182
+
183
+ def __exit__(self, exc_type, exc_val, exc_tb):
184
+ global global_extra_handler
185
+
186
+ global_extra_handler = None
187
+
188
+
189
+ unsafe_torch_load = torch.load
190
+ torch.load = load
191
+ global_extra_handler = None
192
+
sd/stable-diffusion-webui/modules/script_callbacks.py ADDED
@@ -0,0 +1,359 @@
1
+ import sys
2
+ import traceback
3
+ from collections import namedtuple
4
+ import inspect
5
+ from typing import Optional, Dict, Any
6
+
7
+ from fastapi import FastAPI
8
+ from gradio import Blocks
9
+
10
+
11
+ def report_exception(c, job):
12
+ print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
13
+ print(traceback.format_exc(), file=sys.stderr)
14
+
15
+
16
+ class ImageSaveParams:
17
+ def __init__(self, image, p, filename, pnginfo):
18
+ self.image = image
19
+ """the PIL image itself"""
20
+
21
+ self.p = p
22
+ """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
23
+
24
+ self.filename = filename
25
+ """name of file that the image would be saved to"""
26
+
27
+ self.pnginfo = pnginfo
28
+ """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
29
+
30
+
31
+ class CFGDenoiserParams:
32
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
33
+ self.x = x
34
+ """Latent image representation in the process of being denoised"""
35
+
36
+ self.image_cond = image_cond
37
+ """Conditioning image"""
38
+
39
+ self.sigma = sigma
40
+ """Current sigma noise step value"""
41
+
42
+ self.sampling_step = sampling_step
43
+ """Current Sampling step number"""
44
+
45
+ self.total_sampling_steps = total_sampling_steps
46
+ """Total number of sampling steps planned"""
47
+
48
+
49
+ class CFGDenoisedParams:
50
+ def __init__(self, x, sampling_step, total_sampling_steps):
51
+ self.x = x
52
+ """Latent image representation in the process of being denoised"""
53
+
54
+ self.sampling_step = sampling_step
55
+ """Current Sampling step number"""
56
+
57
+ self.total_sampling_steps = total_sampling_steps
58
+ """Total number of sampling steps planned"""
59
+
60
+
61
+ class UiTrainTabParams:
62
+ def __init__(self, txt2img_preview_params):
63
+ self.txt2img_preview_params = txt2img_preview_params
64
+
65
+
66
+ class ImageGridLoopParams:
67
+ def __init__(self, imgs, cols, rows):
68
+ self.imgs = imgs
69
+ self.cols = cols
70
+ self.rows = rows
71
+
72
+
73
+ ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
74
+ callback_map = dict(
75
+ callbacks_app_started=[],
76
+ callbacks_model_loaded=[],
77
+ callbacks_ui_tabs=[],
78
+ callbacks_ui_train_tabs=[],
79
+ callbacks_ui_settings=[],
80
+ callbacks_before_image_saved=[],
81
+ callbacks_image_saved=[],
82
+ callbacks_cfg_denoiser=[],
83
+ callbacks_cfg_denoised=[],
84
+ callbacks_before_component=[],
85
+ callbacks_after_component=[],
86
+ callbacks_image_grid=[],
87
+ callbacks_infotext_pasted=[],
88
+ callbacks_script_unloaded=[],
89
+ callbacks_before_ui=[],
90
+ )
91
+
92
+
93
+ def clear_callbacks():
94
+ for callback_list in callback_map.values():
95
+ callback_list.clear()
96
+
97
+
98
+ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
99
+ for c in callback_map['callbacks_app_started']:
100
+ try:
101
+ c.callback(demo, app)
102
+ except Exception:
103
+ report_exception(c, 'app_started_callback')
104
+
105
+
106
+ def model_loaded_callback(sd_model):
107
+ for c in callback_map['callbacks_model_loaded']:
108
+ try:
109
+ c.callback(sd_model)
110
+ except Exception:
111
+ report_exception(c, 'model_loaded_callback')
112
+
113
+
114
+ def ui_tabs_callback():
115
+ res = []
116
+
117
+ for c in callback_map['callbacks_ui_tabs']:
118
+ try:
119
+ res += c.callback() or []
120
+ except Exception:
121
+ report_exception(c, 'ui_tabs_callback')
122
+
123
+ return res
124
+
125
+
126
+ def ui_train_tabs_callback(params: UiTrainTabParams):
127
+ for c in callback_map['callbacks_ui_train_tabs']:
128
+ try:
129
+ c.callback(params)
130
+ except Exception:
131
+ report_exception(c, 'callbacks_ui_train_tabs')
132
+
133
+
134
+ def ui_settings_callback():
135
+ for c in callback_map['callbacks_ui_settings']:
136
+ try:
137
+ c.callback()
138
+ except Exception:
139
+ report_exception(c, 'ui_settings_callback')
140
+
141
+
142
+ def before_image_saved_callback(params: ImageSaveParams):
143
+ for c in callback_map['callbacks_before_image_saved']:
144
+ try:
145
+ c.callback(params)
146
+ except Exception:
147
+ report_exception(c, 'before_image_saved_callback')
148
+
149
+
150
+ def image_saved_callback(params: ImageSaveParams):
151
+ for c in callback_map['callbacks_image_saved']:
152
+ try:
153
+ c.callback(params)
154
+ except Exception:
155
+ report_exception(c, 'image_saved_callback')
156
+
157
+
158
+ def cfg_denoiser_callback(params: CFGDenoiserParams):
159
+ for c in callback_map['callbacks_cfg_denoiser']:
160
+ try:
161
+ c.callback(params)
162
+ except Exception:
163
+ report_exception(c, 'cfg_denoiser_callback')
164
+
165
+
166
+ def cfg_denoised_callback(params: CFGDenoisedParams):
167
+ for c in callback_map['callbacks_cfg_denoised']:
168
+ try:
169
+ c.callback(params)
170
+ except Exception:
171
+ report_exception(c, 'cfg_denoised_callback')
172
+
173
+
174
+ def before_component_callback(component, **kwargs):
175
+ for c in callback_map['callbacks_before_component']:
176
+ try:
177
+ c.callback(component, **kwargs)
178
+ except Exception:
179
+ report_exception(c, 'before_component_callback')
180
+
181
+
182
+ def after_component_callback(component, **kwargs):
183
+ for c in callback_map['callbacks_after_component']:
184
+ try:
185
+ c.callback(component, **kwargs)
186
+ except Exception:
187
+ report_exception(c, 'after_component_callback')
188
+
189
+
190
+ def image_grid_callback(params: ImageGridLoopParams):
191
+ for c in callback_map['callbacks_image_grid']:
192
+ try:
193
+ c.callback(params)
194
+ except Exception:
195
+ report_exception(c, 'image_grid')
196
+
197
+
198
+ def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
199
+ for c in callback_map['callbacks_infotext_pasted']:
200
+ try:
201
+ c.callback(infotext, params)
202
+ except Exception:
203
+ report_exception(c, 'infotext_pasted')
204
+
205
+
206
+ def script_unloaded_callback():
207
+ for c in reversed(callback_map['callbacks_script_unloaded']):
208
+ try:
209
+ c.callback()
210
+ except Exception:
211
+ report_exception(c, 'script_unloaded')
212
+
213
+
214
+ def before_ui_callback():
215
+ for c in reversed(callback_map['callbacks_before_ui']):
216
+ try:
217
+ c.callback()
218
+ except Exception:
219
+ report_exception(c, 'before_ui')
220
+
221
+
222
+ def add_callback(callbacks, fun):
223
+ stack = [x for x in inspect.stack() if x.filename != __file__]
224
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
225
+
226
+ callbacks.append(ScriptCallback(filename, fun))
227
+
228
+
229
+ def remove_current_script_callbacks():
230
+ stack = [x for x in inspect.stack() if x.filename != __file__]
231
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
232
+ if filename == 'unknown file':
233
+ return
234
+ for callback_list in callback_map.values():
235
+ for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
236
+ callback_list.remove(callback_to_remove)
237
+
238
+
239
+ def remove_callbacks_for_function(callback_func):
240
+ for callback_list in callback_map.values():
241
+ for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
242
+ callback_list.remove(callback_to_remove)
243
+
244
+
245
+ def on_app_started(callback):
246
+ """register a function to be called when the webui has started; the gradio `Blocks` component and
247
+ fastapi `FastAPI` object are passed as the arguments"""
248
+ add_callback(callback_map['callbacks_app_started'], callback)
249
+
250
+
251
+ def on_model_loaded(callback):
252
+ """register a function to be called when the stable diffusion model is created; the model is
253
+ passed as an argument; this function is also called when the script is reloaded. """
254
+ add_callback(callback_map['callbacks_model_loaded'], callback)
255
+
256
+
257
+ def on_ui_tabs(callback):
258
+ """register a function to be called when the UI is creating new tabs.
259
+ The function must either return a None, which means no new tabs to be added, or a list, where
260
+ each element is a tuple:
261
+ (gradio_component, title, elem_id)
262
+
263
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
264
+ title is tab text displayed to user in the UI
265
+ elem_id is HTML id for the tab
266
+ """
267
+ add_callback(callback_map['callbacks_ui_tabs'], callback)
268
+
269
+
270
+ def on_ui_train_tabs(callback):
271
+ """register a function to be called when the UI is creating new tabs for the train tab.
272
+ Create your new tabs with gr.Tab.
273
+ """
274
+ add_callback(callback_map['callbacks_ui_train_tabs'], callback)
275
+
276
+
277
+ def on_ui_settings(callback):
278
+ """register a function to be called before UI settings are populated; add your settings
279
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
280
+ add_callback(callback_map['callbacks_ui_settings'], callback)
281
+
282
+
283
+ def on_before_image_saved(callback):
284
+ """register a function to be called before an image is saved to a file.
285
+ The callback is called with one argument:
286
+ - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
287
+ """
288
+ add_callback(callback_map['callbacks_before_image_saved'], callback)
289
+
290
+
291
+ def on_image_saved(callback):
292
+ """register a function to be called after an image is saved to a file.
293
+ The callback is called with one argument:
294
+ - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
295
+ """
296
+ add_callback(callback_map['callbacks_image_saved'], callback)
297
+
298
+
299
+ def on_cfg_denoiser(callback):
300
+ """register a function to be called in the kdiffusion cfg_denoiser method after building the inner model inputs.
301
+ The callback is called with one argument:
302
+ - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
303
+ """
304
+ add_callback(callback_map['callbacks_cfg_denoiser'], callback)
305
+
306
+
307
+ def on_cfg_denoised(callback):
308
+ """register a function to be called in the kdiffusion cfg_denoiser method after the inner model has computed its result.
309
+ The callback is called with one argument:
310
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
311
+ """
312
+ add_callback(callback_map['callbacks_cfg_denoised'], callback)
313
+
314
+
315
+ def on_before_component(callback):
316
+ """register a function to be called before a component is created.
317
+ The callback is called with arguments:
318
+ - component - gradio component that is about to be created.
319
+ - **kwargs - args to gradio.components.IOComponent.__init__ function
320
+
321
+ Use elem_id/label fields of kwargs to figure out which component it is.
322
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
323
+ """
324
+ add_callback(callback_map['callbacks_before_component'], callback)
325
+
326
+
327
+ def on_after_component(callback):
328
+ """register a function to be called after a component is created. See on_before_component for more."""
329
+ add_callback(callback_map['callbacks_after_component'], callback)
330
+
331
+
332
+ def on_image_grid(callback):
333
+ """register a function to be called before making an image grid.
334
+ The callback is called with one argument:
335
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
336
+ """
337
+ add_callback(callback_map['callbacks_image_grid'], callback)
338
+
339
+
340
+ def on_infotext_pasted(callback):
341
+ """register a function to be called before applying an infotext.
342
+ The callback is called with two arguments:
343
+ - infotext: str - raw infotext.
344
+ - result: Dict[str, any] - parsed infotext parameters.
345
+ """
346
+ add_callback(callback_map['callbacks_infotext_pasted'], callback)
347
+
348
+
349
+ def on_script_unloaded(callback):
350
+ """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
351
+ the script did should be reverted here"""
352
+
353
+ add_callback(callback_map['callbacks_script_unloaded'], callback)
354
+
355
+
356
+ def on_before_ui(callback):
357
+ """register a function to be called before the UI is created."""
358
+
359
+ add_callback(callback_map['callbacks_before_ui'], callback)
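
Taken together, the registration helpers above are what extension scripts import. Below is a minimal sketch of how an extension might wire into them; the option name, handler bodies and prints are illustrative assumptions, not part of the webui:

# illustrative extension code, e.g. extensions/my_extension/scripts/my_extension.py
from modules import script_callbacks, shared

def add_settings():
    # add_option puts a new entry into the Settings tab; the option key here is hypothetical
    shared.opts.add_option("my_extension_enabled", shared.OptionInfo(True, "Enable my extension"))

def log_saved_image(params: script_callbacks.ImageSaveParams):
    # params.pnginfo['parameters'] holds the infotext that was written into the saved image
    print(f"saved {params.filename}")

script_callbacks.on_ui_settings(add_settings)
script_callbacks.on_image_saved(log_saved_image)
script_callbacks.on_script_unloaded(lambda: print("my_extension callbacks removed"))
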
sd/stable-diffusion-webui/modules/script_loading.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+ import importlib.util
5
+ from types import ModuleType
6
+
7
+
8
+ def load_module(path):
9
+ module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
10
+ module = importlib.util.module_from_spec(module_spec)
11
+ module_spec.loader.exec_module(module)
12
+
13
+ return module
14
+
15
+
16
+ def preload_extensions(extensions_dir, parser):
17
+ if not os.path.isdir(extensions_dir):
18
+ return
19
+
20
+ for dirname in sorted(os.listdir(extensions_dir)):
21
+ preload_script = os.path.join(extensions_dir, dirname, "preload.py")
22
+ if not os.path.isfile(preload_script):
23
+ continue
24
+
25
+ try:
26
+ module = load_module(preload_script)
27
+ if hasattr(module, 'preload'):
28
+ module.preload(parser)
29
+
30
+ except Exception:
31
+ print(f"Error running preload() for {preload_script}", file=sys.stderr)
32
+ print(traceback.format_exc(), file=sys.stderr)
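
For reference, preload_extensions() above only looks for a preload() function in each extension directory. A hypothetical extensions/my_extension/preload.py that adds a command line flag before the webui parses its arguments could look roughly like this (the flag name is made up):

# hypothetical extensions/my_extension/preload.py
def preload(parser):
    # parser is the argparse.ArgumentParser used for the webui's command line options
    parser.add_argument("--my-extension-model-dir", type=str, default=None,
                        help="directory with extra models for my_extension (illustrative)")
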
sd/stable-diffusion-webui/modules/scripts.py ADDED
@@ -0,0 +1,501 @@
1
+ import os
2
+ import re
3
+ import sys
4
+ import traceback
5
+ from collections import namedtuple
6
+
7
+ import gradio as gr
8
+
9
+ from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
10
+
11
+ AlwaysVisible = object()
12
+
13
+
14
+ class PostprocessImageArgs:
15
+ def __init__(self, image):
16
+ self.image = image
17
+
18
+
19
+ class Script:
20
+ filename = None
21
+ args_from = None
22
+ args_to = None
23
+ alwayson = False
24
+
25
+ is_txt2img = False
26
+ is_img2img = False
27
+
28
+ """A gr.Group component that has all script's UI inside it"""
29
+ group = None
30
+
31
+ infotext_fields = None
32
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
33
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
34
+ """
35
+
36
+ def title(self):
37
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
38
+
39
+ raise NotImplementedError()
40
+
41
+ def ui(self, is_img2img):
42
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
43
+ The return value should be an array of all components that are used in processing.
44
+ Values of those returned components will be passed to run() and process() functions.
45
+ """
46
+
47
+ pass
48
+
49
+ def show(self, is_img2img):
50
+ """
51
+ is_img2img is True if this function is called for the img2img interface, and False otherwise
52
+
53
+ This function should return:
54
+ - False if the script should not be shown in UI at all
55
+ - True if the script should be shown in UI if it's selected in the scripts dropdown
56
+ - script.AlwaysVisible if the script should be shown in UI at all times
57
+ """
58
+
59
+ return True
60
+
61
+ def run(self, p, *args):
62
+ """
63
+ This function is called if the script has been selected in the script dropdown.
64
+ It must do all processing and return the Processed object with results, same as
65
+ one returned by processing.process_images.
66
+
67
+ Usually the processing is done by calling the processing.process_images function.
68
+
69
+ args contains all values returned by components from ui()
70
+ """
71
+
72
+ pass
73
+
74
+ def process(self, p, *args):
75
+ """
76
+ This function is called before processing begins for AlwaysVisible scripts.
77
+ You can modify the processing object (p) here, inject hooks, etc.
78
+ args contains all values returned by components from ui()
79
+ """
80
+
81
+ pass
82
+
83
+ def process_batch(self, p, *args, **kwargs):
84
+ """
85
+ Same as process(), but called for every batch.
86
+
87
+ **kwargs will have those items:
88
+ - batch_number - index of current batch, from 0 to number of batches-1
89
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
90
+ - seeds - list of seeds for current batch
91
+ - subseeds - list of subseeds for current batch
92
+ """
93
+
94
+ pass
95
+
96
+ def postprocess_batch(self, p, *args, **kwargs):
97
+ """
98
+ Same as process_batch(), but called for every batch after it has been generated.
99
+
100
+ **kwargs will have same items as process_batch, and also:
101
+ - batch_number - index of current batch, from 0 to number of batches-1
102
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
103
+ """
104
+
105
+ pass
106
+
107
+ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
108
+ """
109
+ Called for every image after it has been generated.
110
+ """
111
+
112
+ pass
113
+
114
+ def postprocess(self, p, processed, *args):
115
+ """
116
+ This function is called after processing ends for AlwaysVisible scripts.
117
+ args contains all values returned by components from ui()
118
+ """
119
+
120
+ pass
121
+
122
+ def before_component(self, component, **kwargs):
123
+ """
124
+ Called before a component is created.
125
+ Use elem_id/label fields of kwargs to figure out which component it is.
126
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
127
+ You can return created components in the ui() function to add them to the list of arguments for your processing functions
128
+ """
129
+
130
+ pass
131
+
132
+ def after_component(self, component, **kwargs):
133
+ """
134
+ Called after a component is created. Same as above.
135
+ """
136
+
137
+ pass
138
+
139
+ def describe(self):
140
+ """unused"""
141
+ return ""
142
+
143
+ def elem_id(self, item_id):
144
+ """helper function to generate id for an HTML element; constructs the final id out of script name, tab and user-supplied item_id"""
145
+
146
+ need_tabname = self.show(True) == self.show(False)
147
+ tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
148
+ title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
149
+
150
+ return f'script_{tabname}{title}_{item_id}'
151
+
152
+
153
+ current_basedir = paths.script_path
154
+
155
+
156
+ def basedir():
157
+ """returns the base directory for the current script. For scripts in the main scripts directory,
158
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
159
+ (i.e. extensions/aesthetic/scripts/aesthetic.py), this is the extension's directory (extensions/aesthetic)
160
+ """
161
+ return current_basedir
162
+
163
+
164
+ ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
165
+
166
+ scripts_data = []
167
+ postprocessing_scripts_data = []
168
+ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
169
+
170
+
171
+ def list_scripts(scriptdirname, extension):
172
+ scripts_list = []
173
+
174
+ basedir = os.path.join(paths.script_path, scriptdirname)
175
+ if os.path.exists(basedir):
176
+ for filename in sorted(os.listdir(basedir)):
177
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
178
+
179
+ for ext in extensions.active():
180
+ scripts_list += ext.list_files(scriptdirname, extension)
181
+
182
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
183
+
184
+ return scripts_list
185
+
186
+
187
+ def list_files_with_name(filename):
188
+ res = []
189
+
190
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
191
+
192
+ for dirpath in dirs:
193
+ if not os.path.isdir(dirpath):
194
+ continue
195
+
196
+ path = os.path.join(dirpath, filename)
197
+ if os.path.isfile(path):
198
+ res.append(path)
199
+
200
+ return res
201
+
202
+
203
+ def load_scripts():
204
+ global current_basedir
205
+ scripts_data.clear()
206
+ postprocessing_scripts_data.clear()
207
+ script_callbacks.clear_callbacks()
208
+
209
+ scripts_list = list_scripts("scripts", ".py")
210
+
211
+ syspath = sys.path
212
+
213
+ def register_scripts_from_module(module):
214
+ for key, script_class in module.__dict__.items():
215
+ if type(script_class) != type:
216
+ continue
217
+
218
+ if issubclass(script_class, Script):
219
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
220
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
221
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
222
+
223
+ for scriptfile in sorted(scripts_list):
224
+ try:
225
+ if scriptfile.basedir != paths.script_path:
226
+ sys.path = [scriptfile.basedir] + sys.path
227
+ current_basedir = scriptfile.basedir
228
+
229
+ script_module = script_loading.load_module(scriptfile.path)
230
+ register_scripts_from_module(script_module)
231
+
232
+ except Exception:
233
+ print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
234
+ print(traceback.format_exc(), file=sys.stderr)
235
+
236
+ finally:
237
+ sys.path = syspath
238
+ current_basedir = paths.script_path
239
+
240
+
241
+ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
242
+ try:
243
+ res = func(*args, **kwargs)
244
+ return res
245
+ except Exception:
246
+ print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
247
+ print(traceback.format_exc(), file=sys.stderr)
248
+
249
+ return default
250
+
251
+
252
+ class ScriptRunner:
253
+ def __init__(self):
254
+ self.scripts = []
255
+ self.selectable_scripts = []
256
+ self.alwayson_scripts = []
257
+ self.titles = []
258
+ self.infotext_fields = []
259
+
260
+ def initialize_scripts(self, is_img2img):
261
+ from modules import scripts_auto_postprocessing
262
+
263
+ self.scripts.clear()
264
+ self.alwayson_scripts.clear()
265
+ self.selectable_scripts.clear()
266
+
267
+ auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
268
+
269
+ for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
270
+ script = script_class()
271
+ script.filename = path
272
+ script.is_txt2img = not is_img2img
273
+ script.is_img2img = is_img2img
274
+
275
+ visibility = script.show(script.is_img2img)
276
+
277
+ if visibility == AlwaysVisible:
278
+ self.scripts.append(script)
279
+ self.alwayson_scripts.append(script)
280
+ script.alwayson = True
281
+
282
+ elif visibility:
283
+ self.scripts.append(script)
284
+ self.selectable_scripts.append(script)
285
+
286
+ def setup_ui(self):
287
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
288
+
289
+ inputs = [None]
290
+ inputs_alwayson = [True]
291
+
292
+ def create_script_ui(script, inputs, inputs_alwayson):
293
+ script.args_from = len(inputs)
294
+ script.args_to = len(inputs)
295
+
296
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
297
+
298
+ if controls is None:
299
+ return
300
+
301
+ for control in controls:
302
+ control.custom_script_source = os.path.basename(script.filename)
303
+
304
+ if script.infotext_fields is not None:
305
+ self.infotext_fields += script.infotext_fields
306
+
307
+ inputs += controls
308
+ inputs_alwayson += [script.alwayson for _ in controls]
309
+ script.args_to = len(inputs)
310
+
311
+ for script in self.alwayson_scripts:
312
+ with gr.Group() as group:
313
+ create_script_ui(script, inputs, inputs_alwayson)
314
+
315
+ script.group = group
316
+
317
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
318
+ inputs[0] = dropdown
319
+
320
+ for script in self.selectable_scripts:
321
+ with gr.Group(visible=False) as group:
322
+ create_script_ui(script, inputs, inputs_alwayson)
323
+
324
+ script.group = group
325
+
326
+ def select_script(script_index):
327
+ selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
328
+
329
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
330
+
331
+ def init_field(title):
332
+ """called when an initial value is set from ui-config.json to show script's UI components"""
333
+
334
+ if title == 'None':
335
+ return
336
+
337
+ script_index = self.titles.index(title)
338
+ self.selectable_scripts[script_index].group.visible = True
339
+
340
+ dropdown.init_field = init_field
341
+
342
+ dropdown.change(
343
+ fn=select_script,
344
+ inputs=[dropdown],
345
+ outputs=[script.group for script in self.selectable_scripts]
346
+ )
347
+
348
+ self.script_load_ctr = 0
349
+ def onload_script_visibility(params):
350
+ title = params.get('Script', None)
351
+ if title:
352
+ title_index = self.titles.index(title)
353
+ visibility = title_index == self.script_load_ctr
354
+ self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
355
+ return gr.update(visible=visibility)
356
+ else:
357
+ return gr.update(visible=False)
358
+
359
+ self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
360
+ self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
361
+
362
+ return inputs
363
+
364
+ def run(self, p, *args):
365
+ script_index = args[0]
366
+
367
+ if script_index == 0:
368
+ return None
369
+
370
+ script = self.selectable_scripts[script_index-1]
371
+
372
+ if script is None:
373
+ return None
374
+
375
+ script_args = args[script.args_from:script.args_to]
376
+ processed = script.run(p, *script_args)
377
+
378
+ shared.total_tqdm.clear()
379
+
380
+ return processed
381
+
382
+ def process(self, p):
383
+ for script in self.alwayson_scripts:
384
+ try:
385
+ script_args = p.script_args[script.args_from:script.args_to]
386
+ script.process(p, *script_args)
387
+ except Exception:
388
+ print(f"Error running process: {script.filename}", file=sys.stderr)
389
+ print(traceback.format_exc(), file=sys.stderr)
390
+
391
+ def process_batch(self, p, **kwargs):
392
+ for script in self.alwayson_scripts:
393
+ try:
394
+ script_args = p.script_args[script.args_from:script.args_to]
395
+ script.process_batch(p, *script_args, **kwargs)
396
+ except Exception:
397
+ print(f"Error running process_batch: {script.filename}", file=sys.stderr)
398
+ print(traceback.format_exc(), file=sys.stderr)
399
+
400
+ def postprocess(self, p, processed):
401
+ for script in self.alwayson_scripts:
402
+ try:
403
+ script_args = p.script_args[script.args_from:script.args_to]
404
+ script.postprocess(p, processed, *script_args)
405
+ except Exception:
406
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
407
+ print(traceback.format_exc(), file=sys.stderr)
408
+
409
+ def postprocess_batch(self, p, images, **kwargs):
410
+ for script in self.alwayson_scripts:
411
+ try:
412
+ script_args = p.script_args[script.args_from:script.args_to]
413
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
414
+ except Exception:
415
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
416
+ print(traceback.format_exc(), file=sys.stderr)
417
+
418
+ def postprocess_image(self, p, pp: PostprocessImageArgs):
419
+ for script in self.alwayson_scripts:
420
+ try:
421
+ script_args = p.script_args[script.args_from:script.args_to]
422
+ script.postprocess_image(p, pp, *script_args)
423
+ except Exception:
424
+ print(f"Error running postprocess_image: {script.filename}", file=sys.stderr)
425
+ print(traceback.format_exc(), file=sys.stderr)
426
+
427
+ def before_component(self, component, **kwargs):
428
+ for script in self.scripts:
429
+ try:
430
+ script.before_component(component, **kwargs)
431
+ except Exception:
432
+ print(f"Error running before_component: {script.filename}", file=sys.stderr)
433
+ print(traceback.format_exc(), file=sys.stderr)
434
+
435
+ def after_component(self, component, **kwargs):
436
+ for script in self.scripts:
437
+ try:
438
+ script.after_component(component, **kwargs)
439
+ except Exception:
440
+ print(f"Error running after_component: {script.filename}", file=sys.stderr)
441
+ print(traceback.format_exc(), file=sys.stderr)
442
+
443
+ def reload_sources(self, cache):
444
+ for si, script in list(enumerate(self.scripts)):
445
+ args_from = script.args_from
446
+ args_to = script.args_to
447
+ filename = script.filename
448
+
449
+ module = cache.get(filename, None)
450
+ if module is None:
451
+ module = script_loading.load_module(script.filename)
452
+ cache[filename] = module
453
+
454
+ for key, script_class in module.__dict__.items():
455
+ if type(script_class) == type and issubclass(script_class, Script):
456
+ self.scripts[si] = script_class()
457
+ self.scripts[si].filename = filename
458
+ self.scripts[si].args_from = args_from
459
+ self.scripts[si].args_to = args_to
460
+
461
+
462
+ scripts_txt2img = ScriptRunner()
463
+ scripts_img2img = ScriptRunner()
464
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
465
+ scripts_current: ScriptRunner = None
466
+
467
+
468
+ def reload_script_body_only():
469
+ cache = {}
470
+ scripts_txt2img.reload_sources(cache)
471
+ scripts_img2img.reload_sources(cache)
472
+
473
+
474
+ def reload_scripts():
475
+ global scripts_txt2img, scripts_img2img, scripts_postproc
476
+
477
+ load_scripts()
478
+
479
+ scripts_txt2img = ScriptRunner()
480
+ scripts_img2img = ScriptRunner()
481
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
482
+
483
+
484
+ def IOComponent_init(self, *args, **kwargs):
485
+ if scripts_current is not None:
486
+ scripts_current.before_component(self, **kwargs)
487
+
488
+ script_callbacks.before_component_callback(self, **kwargs)
489
+
490
+ res = original_IOComponent_init(self, *args, **kwargs)
491
+
492
+ script_callbacks.after_component_callback(self, **kwargs)
493
+
494
+ if scripts_current is not None:
495
+ scripts_current.after_component(self, **kwargs)
496
+
497
+ return res
498
+
499
+
500
+ original_IOComponent_init = gr.components.IOComponent.__init__
501
+ gr.components.IOComponent.__init__ = IOComponent_init
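
As a sketch of what the Script base class above expects from an extension (the class name and behaviour here are illustrative): title() labels the dropdown entry, ui() returns the gradio components whose values are passed back to the processing functions, and run() normally delegates to processing.process_images:

# hypothetical scripts/example_append_note.py
import gradio as gr
from modules import scripts
from modules.processing import process_images

class ExampleAppendNote(scripts.Script):
    def title(self):
        return "Example: append note to prompt"

    def ui(self, is_img2img):
        note = gr.Textbox(label="Note appended to the prompt", value="")
        return [note]

    def run(self, p, note):
        # p is the StableDiffusionProcessing object; returning a Processed object ends the job
        if note:
            p.prompt = f"{p.prompt}, {note}"
        return process_images(p)
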
sd/stable-diffusion-webui/modules/scripts_auto_postprocessing.py ADDED
@@ -0,0 +1,42 @@
1
+ from modules import scripts, scripts_postprocessing, shared
2
+
3
+
4
+ class ScriptPostprocessingForMainUI(scripts.Script):
5
+ def __init__(self, script_postproc):
6
+ self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
7
+ self.postprocessing_controls = None
8
+
9
+ def title(self):
10
+ return self.script.name
11
+
12
+ def show(self, is_img2img):
13
+ return scripts.AlwaysVisible
14
+
15
+ def ui(self, is_img2img):
16
+ self.postprocessing_controls = self.script.ui()
17
+ return self.postprocessing_controls.values()
18
+
19
+ def postprocess_image(self, p, script_pp, *args):
20
+ args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
21
+
22
+ pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
23
+ pp.info = {}
24
+ self.script.process(pp, **args_dict)
25
+ p.extra_generation_params.update(pp.info)
26
+ script_pp.image = pp.image
27
+
28
+
29
+ def create_auto_preprocessing_script_data():
30
+ from modules import scripts
31
+
32
+ res = []
33
+
34
+ for name in shared.opts.postprocessing_enable_in_main_ui:
35
+ script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
36
+ if script is None:
37
+ continue
38
+
39
+ constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
40
+ res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
41
+
42
+ return res
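
The wrapper above is only instantiated for postprocessing scripts whose names appear in the postprocessing_enable_in_main_ui setting; a rough illustration follows (the script name used here is an assumption about what is installed):

# illustrative: turn an extras-tab script into an always-on txt2img/img2img script
from modules import shared, scripts_auto_postprocessing

shared.opts.postprocessing_enable_in_main_ui = ["GFPGAN"]  # assumed name of an installed postprocessing script
for data in scripts_auto_postprocessing.create_auto_preprocessing_script_data():
    print(data.script_class().title())  # title of each wrapped script
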
sd/stable-diffusion-webui/modules/scripts_postprocessing.py ADDED
@@ -0,0 +1,152 @@
1
+ import os
2
+ import gradio as gr
3
+
4
+ from modules import errors, shared
5
+
6
+
7
+ class PostprocessedImage:
8
+ def __init__(self, image):
9
+ self.image = image
10
+ self.info = {}
11
+
12
+
13
+ class ScriptPostprocessing:
14
+ filename = None
15
+ controls = None
16
+ args_from = None
17
+ args_to = None
18
+
19
+ order = 1000
20
+ """scripts will be ordered by this value in the postprocessing UI"""
21
+
22
+ name = None
23
+ """the name of the script, shown as its title in the postprocessing UI"""
24
+
25
+ group = None
26
+ """A gr.Group component that has all script's UI inside it"""
27
+
28
+ def ui(self):
29
+ """
30
+ This function should create gradio UI elements. See https://gradio.app/docs/#components
31
+ The return value should be a dictionary that maps parameter names to components used in processing.
32
+ Values of those components will be passed to process() function.
33
+ """
34
+
35
+ pass
36
+
37
+ def process(self, pp: PostprocessedImage, **args):
38
+ """
39
+ This function is called to postprocess the image.
40
+ args contains a dictionary with all values returned by components from ui()
41
+ """
42
+
43
+ pass
44
+
45
+ def image_changed(self):
46
+ pass
47
+
48
+
49
+
50
+
51
+ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
52
+ try:
53
+ res = func(*args, **kwargs)
54
+ return res
55
+ except Exception as e:
56
+ errors.display(e, f"calling {filename}/{funcname}")
57
+
58
+ return default
59
+
60
+
61
+ class ScriptPostprocessingRunner:
62
+ def __init__(self):
63
+ self.scripts = None
64
+ self.ui_created = False
65
+
66
+ def initialize_scripts(self, scripts_data):
67
+ self.scripts = []
68
+
69
+ for script_class, path, basedir, script_module in scripts_data:
70
+ script: ScriptPostprocessing = script_class()
71
+ script.filename = path
72
+
73
+ if script.name == "Simple Upscale":
74
+ continue
75
+
76
+ self.scripts.append(script)
77
+
78
+ def create_script_ui(self, script, inputs):
79
+ script.args_from = len(inputs)
80
+ script.args_to = len(inputs)
81
+
82
+ script.controls = wrap_call(script.ui, script.filename, "ui")
83
+
84
+ for control in script.controls.values():
85
+ control.custom_script_source = os.path.basename(script.filename)
86
+
87
+ inputs += list(script.controls.values())
88
+ script.args_to = len(inputs)
89
+
90
+ def scripts_in_preferred_order(self):
91
+ if self.scripts is None:
92
+ import modules.scripts
93
+ self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
94
+
95
+ scripts_order = shared.opts.postprocessing_operation_order
96
+
97
+ def script_score(name):
98
+ for i, possible_match in enumerate(scripts_order):
99
+ if possible_match == name:
100
+ return i
101
+
102
+ return len(self.scripts)
103
+
104
+ script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
105
+
106
+ return sorted(self.scripts, key=lambda x: script_scores[x.name])
107
+
108
+ def setup_ui(self):
109
+ inputs = []
110
+
111
+ for script in self.scripts_in_preferred_order():
112
+ with gr.Box() as group:
113
+ self.create_script_ui(script, inputs)
114
+
115
+ script.group = group
116
+
117
+ self.ui_created = True
118
+ return inputs
119
+
120
+ def run(self, pp: PostprocessedImage, args):
121
+ for script in self.scripts_in_preferred_order():
122
+ shared.state.job = script.name
123
+
124
+ script_args = args[script.args_from:script.args_to]
125
+
126
+ process_args = {}
127
+ for (name, component), value in zip(script.controls.items(), script_args):
128
+ process_args[name] = value
129
+
130
+ script.process(pp, **process_args)
131
+
132
+ def create_args_for_run(self, scripts_args):
133
+ if not self.ui_created:
134
+ with gr.Blocks(analytics_enabled=False):
135
+ self.setup_ui()
136
+
137
+ scripts = self.scripts_in_preferred_order()
138
+ args = [None] * max([x.args_to for x in scripts])
139
+
140
+ for script in scripts:
141
+ script_args_dict = scripts_args.get(script.name, None)
142
+ if script_args_dict is not None:
143
+
144
+ for i, name in enumerate(script.controls):
145
+ args[script.args_from + i] = script_args_dict.get(name, None)
146
+
147
+ return args
148
+
149
+ def image_changed(self):
150
+ for script in self.scripts_in_preferred_order():
151
+ script.image_changed()
152
+
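
A sketch of the contract this runner expects from a postprocessing script (the class, labels and info key below are invented for illustration): ui() returns a dict whose keys become the keyword arguments of process(), and anything written to pp.info is merged into the image's generation parameters:

# hypothetical scripts/postprocessing_caption.py
import gradio as gr
from modules import scripts_postprocessing

class ScriptPostprocessingCaption(scripts_postprocessing.ScriptPostprocessing):
    name = "Caption stamp"
    order = 9000  # sorted after the built-in operations

    def ui(self):
        enable = gr.Checkbox(label="Stamp caption", value=False)
        text = gr.Textbox(label="Caption text", value="")
        return {"enable": enable, "text": text}

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, text):
        if not enable or not text:
            return
        pp.info["Caption stamp"] = text  # recorded alongside the processed image
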
sd/stable-diffusion-webui/modules/sd_disable_initialization.py ADDED
@@ -0,0 +1,93 @@
1
+ import ldm.modules.encoders.modules
2
+ import open_clip
3
+ import torch
4
+ import transformers.utils.hub
5
+
6
+
7
+ class DisableInitialization:
8
+ """
9
+ When an object of this class enters a `with` block, it:
10
+ - prevents torch's layer initialization functions from working
11
+ - changes CLIP and OpenCLIP to not download model weights
12
+ - changes CLIP to not make requests to check if there is a new version of a file you already have
13
+
14
+ When it leaves the block, it reverts everything to how it was before.
15
+
16
+ Use it like this:
17
+ ```
18
+ with DisableInitialization():
19
+ do_things()
20
+ ```
21
+ """
22
+
23
+ def __init__(self, disable_clip=True):
24
+ self.replaced = []
25
+ self.disable_clip = disable_clip
26
+
27
+ def replace(self, obj, field, func):
28
+ original = getattr(obj, field, None)
29
+ if original is None:
30
+ return None
31
+
32
+ self.replaced.append((obj, field, original))
33
+ setattr(obj, field, func)
34
+
35
+ return original
36
+
37
+ def __enter__(self):
38
+ def do_nothing(*args, **kwargs):
39
+ pass
40
+
41
+ def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
42
+ return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
43
+
44
+ def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
45
+ res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
46
+ res.name_or_path = pretrained_model_name_or_path
47
+ return res
48
+
49
+ def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
50
+ args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
51
+ return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
52
+
53
+ def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
54
+
55
+ # this file is always 404, prevent making request
56
+ if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
57
+ return None
58
+
59
+ try:
60
+ res = original(url, *args, local_files_only=True, **kwargs)
61
+ if res is None:
62
+ res = original(url, *args, local_files_only=False, **kwargs)
63
+ return res
64
+ except Exception as e:
65
+ return original(url, *args, local_files_only=False, **kwargs)
66
+
67
+ def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
68
+ return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
69
+
70
+ def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
71
+ return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
72
+
73
+ def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
74
+ return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
75
+
76
+ self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
77
+ self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
78
+ self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
79
+
80
+ if self.disable_clip:
81
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
82
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
83
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
84
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
85
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
86
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
87
+
88
+ def __exit__(self, exc_type, exc_val, exc_tb):
89
+ for obj, field, original in self.replaced:
90
+ setattr(obj, field, original)
91
+
92
+ self.replaced.clear()
93
+
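
The class docstring above already shows the intended pattern; concretely, the checkpoint loader wraps model construction in it so that random weight initialization and CLIP downloads are skipped, because real weights are loaded from a checkpoint right after. A minimal sketch, assuming instantiate_from_config from ldm.util and an already-parsed sd_config:

# sketch: build the model with initialization disabled, then load checkpoint weights afterwards
from ldm.util import instantiate_from_config
from modules import sd_disable_initialization

def build_model_skipping_init(sd_config):
    with sd_disable_initialization.DisableInitialization(disable_clip=True):
        model = instantiate_from_config(sd_config.model)  # layer init functions are no-ops inside the block
    return model
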
sd/stable-diffusion-webui/modules/sd_hijack.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch
2
+ from torch.nn.functional import silu
3
+ from types import MethodType
4
+
5
+ import modules.textual_inversion.textual_inversion
6
+ from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
7
+ from modules.hypernetworks import hypernetwork
8
+ from modules.shared import cmd_opts
9
+ from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
10
+
11
+ import ldm.modules.attention
12
+ import ldm.modules.diffusionmodules.model
13
+ import ldm.modules.diffusionmodules.openaimodel
14
+ import ldm.models.diffusion.ddim
15
+ import ldm.models.diffusion.plms
16
+ import ldm.modules.encoders.modules
17
+
18
+ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
19
+ diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
20
+ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
21
+
22
+ # new memory efficient cross attention blocks do not support hypernets and we already
23
+ # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
24
+ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
25
+ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
26
+
27
+ # silence new console spam from SD2
28
+ ldm.modules.attention.print = lambda *args: None
29
+ ldm.modules.diffusionmodules.model.print = lambda *args: None
30
+
31
+
32
+ def apply_optimizations():
33
+ undo_optimizations()
34
+
35
+ ldm.modules.diffusionmodules.model.nonlinearity = silu
36
+ ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
37
+
38
+ optimization_method = None
39
+
40
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
41
+ print("Applying xformers cross attention optimization.")
42
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
43
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
44
+ optimization_method = 'xformers'
45
+ elif cmd_opts.opt_sub_quad_attention:
46
+ print("Applying sub-quadratic cross attention optimization.")
47
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
48
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
49
+ optimization_method = 'sub-quadratic'
50
+ elif cmd_opts.opt_split_attention_v1:
51
+ print("Applying v1 cross attention optimization.")
52
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
53
+ optimization_method = 'V1'
54
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
55
+ print("Applying cross attention optimization (InvokeAI).")
56
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
57
+ optimization_method = 'InvokeAI'
58
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
59
+ print("Applying cross attention optimization (Doggettx).")
60
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
61
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
62
+ optimization_method = 'Doggettx'
63
+
64
+ return optimization_method
65
+
66
+
67
+ def undo_optimizations():
68
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
69
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
70
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
71
+
72
+
73
+ def fix_checkpoint():
74
+ """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
75
+ checkpoints to be added when not training (there's a warning)"""
76
+
77
+ pass
78
+
79
+
80
+ def weighted_loss(sd_model, pred, target, mean=True):
81
+ #Calculate the weight normally, but ignore the mean
82
+ loss = sd_model._old_get_loss(pred, target, mean=False)
83
+
84
+ #Check if we have weights available
85
+ weight = getattr(sd_model, '_custom_loss_weight', None)
86
+ if weight is not None:
87
+ loss *= weight
88
+
89
+ #Return the loss, as mean if specified
90
+ return loss.mean() if mean else loss
91
+
92
+ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
93
+ try:
94
+ #Temporarily append weights to a place accessible during loss calc
95
+ sd_model._custom_loss_weight = w
96
+
97
+ #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
98
+ #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
99
+ if not hasattr(sd_model, '_old_get_loss'):
100
+ sd_model._old_get_loss = sd_model.get_loss
101
+ sd_model.get_loss = MethodType(weighted_loss, sd_model)
102
+
103
+ #Run the standard forward function, but with the patched 'get_loss'
104
+ return sd_model.forward(x, c, *args, **kwargs)
105
+ finally:
106
+ try:
107
+ #Delete temporary weights if appended
108
+ del sd_model._custom_loss_weight
109
+ except AttributeError as e:
110
+ pass
111
+
112
+ #If we have an old loss function, reset the loss function to the original one
113
+ if hasattr(sd_model, '_old_get_loss'):
114
+ sd_model.get_loss = sd_model._old_get_loss
115
+ del sd_model._old_get_loss
116
+
117
+ def apply_weighted_forward(sd_model):
118
+ #Add new function 'weighted_forward' that can be called to calc weighted loss
119
+ sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
120
+
121
+ def undo_weighted_forward(sd_model):
122
+ try:
123
+ del sd_model.weighted_forward
124
+ except AttributeError as e:
125
+ pass
126
+
127
+
128
+ class StableDiffusionModelHijack:
129
+ fixes = None
130
+ comments = []
131
+ layers = None
132
+ circular_enabled = False
133
+ clip = None
134
+ optimization_method = None
135
+
136
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
137
+
138
+ def __init__(self):
139
+ self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
140
+
141
+ def hijack(self, m):
142
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
143
+ model_embeddings = m.cond_stage_model.roberta.embeddings
144
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
145
+ m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
146
+
147
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
148
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
149
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
150
+ m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
151
+
152
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
153
+ m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
154
+ m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
155
+
156
+ apply_weighted_forward(m)
157
+ if m.cond_stage_key == "edit":
158
+ sd_hijack_unet.hijack_ddpm_edit()
159
+
160
+ self.optimization_method = apply_optimizations()
161
+
162
+ self.clip = m.cond_stage_model
163
+
164
+ def flatten(el):
165
+ flattened = [flatten(children) for children in el.children()]
166
+ res = [el]
167
+ for c in flattened:
168
+ res += c
169
+ return res
170
+
171
+ self.layers = flatten(m)
172
+
173
+ def undo_hijack(self, m):
174
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
175
+ m.cond_stage_model = m.cond_stage_model.wrapped
176
+
177
+ elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
178
+ m.cond_stage_model = m.cond_stage_model.wrapped
179
+
180
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
181
+ if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
182
+ model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
183
+ elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
184
+ m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
185
+ m.cond_stage_model = m.cond_stage_model.wrapped
186
+
187
+ undo_optimizations()
188
+ undo_weighted_forward(m)
189
+
190
+ self.apply_circular(False)
191
+ self.layers = None
192
+ self.clip = None
193
+
194
+ def apply_circular(self, enable):
195
+ if self.circular_enabled == enable:
196
+ return
197
+
198
+ self.circular_enabled = enable
199
+
200
+ for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
201
+ layer.padding_mode = 'circular' if enable else 'zeros'
202
+
203
+ def clear_comments(self):
204
+ self.comments = []
205
+
206
+ def get_prompt_lengths(self, text):
207
+ _, token_count = self.clip.process_texts([text])
208
+
209
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
210
+
211
+
212
+ class EmbeddingsWithFixes(torch.nn.Module):
213
+ def __init__(self, wrapped, embeddings):
214
+ super().__init__()
215
+ self.wrapped = wrapped
216
+ self.embeddings = embeddings
217
+
218
+ def forward(self, input_ids):
219
+ batch_fixes = self.embeddings.fixes
220
+ self.embeddings.fixes = None
221
+
222
+ inputs_embeds = self.wrapped(input_ids)
223
+
224
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
225
+ return inputs_embeds
226
+
227
+ vecs = []
228
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
229
+ for offset, embedding in fixes:
230
+ emb = devices.cond_cast_unet(embedding.vec)
231
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
232
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
233
+
234
+ vecs.append(tensor)
235
+
236
+ return torch.stack(vecs)
237
+
238
+
239
+ def add_circular_option_to_conv_2d():
240
+ conv2d_constructor = torch.nn.Conv2d.__init__
241
+
242
+ def conv2d_constructor_circular(self, *args, **kwargs):
243
+ return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
244
+
245
+ torch.nn.Conv2d.__init__ = conv2d_constructor_circular
246
+
247
+
248
+ model_hijack = StableDiffusionModelHijack()
249
+
250
+
251
+ def register_buffer(self, name, attr):
252
+ """
253
+ Fix register buffer bug for Mac OS.
254
+ """
255
+
256
+ if type(attr) == torch.Tensor:
257
+ if attr.device != devices.device:
258
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
259
+
260
+ setattr(self, name, attr)
261
+
262
+
263
+ ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
264
+ ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
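
For orientation, the hijack/undo pair above is driven by the checkpoint-loading code when models are created or swapped; a rough sketch of that lifecycle (the helper function name is invented):

# sketch: what happens around a model swap
from modules import sd_hijack

def reapply_hijack(old_model, new_model):
    sd_hijack.model_hijack.undo_hijack(old_model)   # restore the original cond_stage_model wrappers
    sd_hijack.model_hijack.hijack(new_model)        # wrap CLIP and pick a cross attention optimization
    # token count for a prompt plus the padded chunk size it will occupy
    tokens, target = sd_hijack.model_hijack.get_prompt_lengths("a photo of a cat")
    return tokens, target
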
sd/stable-diffusion-webui/modules/sd_hijack_checkpoint.py ADDED
@@ -0,0 +1,46 @@
1
+ from torch.utils.checkpoint import checkpoint
2
+
3
+ import ldm.modules.attention
4
+ import ldm.modules.diffusionmodules.openaimodel
5
+
6
+
7
+ def BasicTransformerBlock_forward(self, x, context=None):
8
+ return checkpoint(self._forward, x, context)
9
+
10
+
11
+ def AttentionBlock_forward(self, x):
12
+ return checkpoint(self._forward, x)
13
+
14
+
15
+ def ResBlock_forward(self, x, emb):
16
+ return checkpoint(self._forward, x, emb)
17
+
18
+
19
+ stored = []
20
+
21
+
22
+ def add():
23
+ if len(stored) != 0:
24
+ return
25
+
26
+ stored.extend([
27
+ ldm.modules.attention.BasicTransformerBlock.forward,
28
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
29
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
30
+ ])
31
+
32
+ ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
33
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
34
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
35
+
36
+
37
+ def remove():
38
+ if len(stored) == 0:
39
+ return
40
+
41
+ ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
42
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
43
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
44
+
45
+ stored.clear()
46
+
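
add() and remove() above are meant to be paired around training so that gradient checkpointing is only active while an embedding or hypernetwork is being trained; a minimal sketch of that pairing (the per-step training callable is a stand-in):

from modules import sd_hijack_checkpoint

def train_with_gradient_checkpointing(train_one_step, steps):
    sd_hijack_checkpoint.add()  # route BasicTransformerBlock/ResBlock/AttentionBlock through torch checkpointing
    try:
        for _ in range(steps):
            train_one_step()    # hypothetical per-step training function
    finally:
        sd_hijack_checkpoint.remove()  # restore the original forward methods
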
sd/stable-diffusion-webui/modules/sd_hijack_clip.py ADDED
@@ -0,0 +1,317 @@
1
+ import math
2
+ from collections import namedtuple
3
+
4
+ import torch
5
+
6
+ from modules import prompt_parser, devices, sd_hijack
7
+ from modules.shared import opts
8
+
9
+
10
+ class PromptChunk:
11
+ """
12
+ This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
13
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
14
+ Each PromptChunk contains an exact number of tokens - 77, which includes one start token and one end token,
15
+ so just 75 tokens from prompt.
16
+ """
17
+
18
+ def __init__(self):
19
+ self.tokens = []
20
+ self.multipliers = []
21
+ self.fixes = []
22
+
23
+
24
+ PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
25
+ """An object of this type is a marker showing that textual inversion embedding's vectors have to be placed at offset in the prompt
26
+ chunk. Those objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
27
+ are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
28
+
29
+
30
+ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
31
+ """A pytorch module that is a wrapper for the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
32
+ have unlimited prompt length and assign weights to tokens in prompt.
33
+ """
34
+
35
+ def __init__(self, wrapped, hijack):
36
+ super().__init__()
37
+
38
+ self.wrapped = wrapped
39
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
40
+ depending on model."""
41
+
42
+ self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
43
+ self.chunk_length = 75
44
+
45
+ def empty_chunk(self):
46
+ """creates an empty PromptChunk and returns it"""
47
+
48
+ chunk = PromptChunk()
49
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
50
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
51
+ return chunk
52
+
53
+ def get_target_prompt_token_count(self, token_count):
54
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
55
+
56
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
57
+
58
+ def tokenize(self, texts):
59
+ """Converts a batch of texts into a batch of token ids"""
60
+
61
+ raise NotImplementedError
62
+
63
+ def encode_with_transformers(self, tokens):
64
+ """
65
+ converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
66
+ All python lists with tokens are assumed to have same length, usually 77.
67
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
68
+ model - can be 768 or 1024.
69
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
70
+ """
71
+
72
+ raise NotImplementedError
73
+
74
+ def encode_embedding_init_text(self, init_text, nvpt):
75
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
76
+ transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
77
+
78
+ raise NotImplementedError
79
+
80
+ def tokenize_line(self, line):
81
+ """
82
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
83
+ represent the prompt.
84
+ Returns the list and the total number of tokens in the prompt.
85
+ """
86
+
87
+ if opts.enable_emphasis:
88
+ parsed = prompt_parser.parse_prompt_attention(line)
89
+ else:
90
+ parsed = [[line, 1.0]]
91
+
92
+ tokenized = self.tokenize([text for text, _ in parsed])
93
+
94
+ chunks = []
95
+ chunk = PromptChunk()
96
+ token_count = 0
97
+ last_comma = -1
98
+
99
+ def next_chunk(is_last=False):
100
+ """puts current chunk into the list of results and produces the next one - empty;
101
+ if is_last is true, the <end-of-text> padding tokens at the end won't add to token_count"""
102
+ nonlocal token_count
103
+ nonlocal last_comma
104
+ nonlocal chunk
105
+
106
+ if is_last:
107
+ token_count += len(chunk.tokens)
108
+ else:
109
+ token_count += self.chunk_length
110
+
111
+ to_add = self.chunk_length - len(chunk.tokens)
112
+ if to_add > 0:
113
+ chunk.tokens += [self.id_end] * to_add
114
+ chunk.multipliers += [1.0] * to_add
115
+
116
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
117
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
118
+
119
+ last_comma = -1
120
+ chunks.append(chunk)
121
+ chunk = PromptChunk()
122
+
123
+ for tokens, (text, weight) in zip(tokenized, parsed):
124
+ if text == 'BREAK' and weight == -1:
125
+ next_chunk()
126
+ continue
127
+
128
+ position = 0
129
+ while position < len(tokens):
130
+ token = tokens[position]
131
+
132
+ if token == self.comma_token:
133
+ last_comma = len(chunk.tokens)
134
+
135
+ # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
136
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
137
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
138
+ break_location = last_comma + 1
139
+
140
+ reloc_tokens = chunk.tokens[break_location:]
141
+ reloc_mults = chunk.multipliers[break_location:]
142
+
143
+ chunk.tokens = chunk.tokens[:break_location]
144
+ chunk.multipliers = chunk.multipliers[:break_location]
145
+
146
+ next_chunk()
147
+ chunk.tokens = reloc_tokens
148
+ chunk.multipliers = reloc_mults
149
+
150
+ if len(chunk.tokens) == self.chunk_length:
151
+ next_chunk()
152
+
153
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
154
+ if embedding is None:
155
+ chunk.tokens.append(token)
156
+ chunk.multipliers.append(weight)
157
+ position += 1
158
+ continue
159
+
160
+ emb_len = int(embedding.vec.shape[0])
161
+ if len(chunk.tokens) + emb_len > self.chunk_length:
162
+ next_chunk()
163
+
164
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
165
+
166
+ chunk.tokens += [0] * emb_len
167
+ chunk.multipliers += [weight] * emb_len
168
+ position += embedding_length_in_tokens
169
+
170
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
171
+ next_chunk(is_last=True)
172
+
173
+ return chunks, token_count
174
+
175
+ def process_texts(self, texts):
176
+ """
177
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
178
+ length, in tokens, of all texts.
179
+ """
180
+
181
+ token_count = 0
182
+
183
+ cache = {}
184
+ batch_chunks = []
185
+ for line in texts:
186
+ if line in cache:
187
+ chunks = cache[line]
188
+ else:
189
+ chunks, current_token_count = self.tokenize_line(line)
190
+ token_count = max(current_token_count, token_count)
191
+
192
+ cache[line] = chunks
193
+
194
+ batch_chunks.append(chunks)
195
+
196
+ return batch_chunks, token_count
197
+
198
+ def forward(self, texts):
199
+ """
200
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
201
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
202
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
203
+ An example shape returned by this function can be: (2, 77, 768).
204
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
205
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
206
+ """
207
+
208
+ if opts.use_old_emphasis_implementation:
209
+ import modules.sd_hijack_clip_old
210
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
211
+
212
+ batch_chunks, token_count = self.process_texts(texts)
213
+
214
+ used_embeddings = {}
215
+ chunk_count = max([len(x) for x in batch_chunks])
216
+
217
+ zs = []
218
+ for i in range(chunk_count):
219
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
220
+
221
+ tokens = [x.tokens for x in batch_chunk]
222
+ multipliers = [x.multipliers for x in batch_chunk]
223
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
224
+
225
+ for fixes in self.hijack.fixes:
226
+ for position, embedding in fixes:
227
+ used_embeddings[embedding.name] = embedding
228
+
229
+ z = self.process_tokens(tokens, multipliers)
230
+ zs.append(z)
231
+
232
+ if len(used_embeddings) > 0:
233
+ embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
234
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
235
+
236
+ return torch.hstack(zs)
237
+
238
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
239
+ """
240
+ Sends a single prompt chunk to be encoded by the transformers neural network.
241
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
242
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
243
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
244
+ corresponds to one token.
245
+ """
246
+ tokens = torch.asarray(remade_batch_tokens).to(devices.device)
247
+
248
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
249
+ if self.id_end != self.id_pad:
250
+ for batch_pos in range(len(remade_batch_tokens)):
251
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
252
+ tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
253
+
254
+ z = self.encode_with_transformers(tokens)
255
+
256
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
257
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
258
+ original_mean = z.mean()
259
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
260
+ new_mean = z.mean()
261
+ z = z * (original_mean / new_mean)
262
+
263
+ return z
264
+
265
+
266
+ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
267
+ def __init__(self, wrapped, hijack):
268
+ super().__init__(wrapped, hijack)
269
+ self.tokenizer = wrapped.tokenizer
270
+
271
+ vocab = self.tokenizer.get_vocab()
272
+
273
+ self.comma_token = vocab.get(',</w>', None)
274
+
275
+ self.token_mults = {}
276
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
277
+ for text, ident in tokens_with_parens:
278
+ mult = 1.0
279
+ for c in text:
280
+ if c == '[':
281
+ mult /= 1.1
282
+ if c == ']':
283
+ mult *= 1.1
284
+ if c == '(':
285
+ mult *= 1.1
286
+ if c == ')':
287
+ mult /= 1.1
288
+
289
+ if mult != 1.0:
290
+ self.token_mults[ident] = mult
291
+
292
+ self.id_start = self.wrapped.tokenizer.bos_token_id
293
+ self.id_end = self.wrapped.tokenizer.eos_token_id
294
+ self.id_pad = self.id_end
295
+
296
+ def tokenize(self, texts):
297
+ tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
298
+
299
+ return tokenized
300
+
301
+ def encode_with_transformers(self, tokens):
302
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
303
+
304
+ if opts.CLIP_stop_at_last_layers > 1:
305
+ z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
306
+ z = self.wrapped.transformer.text_model.final_layer_norm(z)
307
+ else:
308
+ z = outputs.last_hidden_state
309
+
310
+ return z
311
+
312
+ def encode_embedding_init_text(self, init_text, nvpt):
313
+ embedding_layer = self.wrapped.transformer.text_model.embeddings
314
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
315
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
316
+
317
+ return embedded
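Aside: the chunking behaviour implemented by tokenize_line() above can be summarised with a minimal standalone sketch - not the webui API, just the core idea of splitting token ids into 75-token chunks and padding each to 77 with start/end ids (the helper name and the dummy ids 0/1 are illustrative):

    def split_into_chunks(tokens, id_start=0, id_end=1, chunk_length=75):
        # split raw token ids into chunks of at most chunk_length tokens each
        chunks = []
        for i in range(0, max(len(tokens), 1), chunk_length):
            part = tokens[i:i + chunk_length]
            # pad with the end-of-text id, then add start/end markers -> 77 tokens total
            part = part + [id_end] * (chunk_length - len(part))
            chunks.append([id_start] + part + [id_end])
        return chunks

    # e.g. a 100-token prompt becomes two 77-token chunks
    assert [len(c) for c in split_into_chunks(list(range(100)))] == [77, 77]

The real tokenize_line() additionally honours emphasis weights, BREAK, comma backtracking and textual inversion embeddings, as shown above.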
sd/stable-diffusion-webui/modules/sd_hijack_clip_old.py ADDED
@@ -0,0 +1,81 @@
1
+ from modules import sd_hijack_clip
2
+ from modules import shared
3
+
4
+
5
+ def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
6
+ id_start = self.id_start
7
+ id_end = self.id_end
8
+ maxlen = self.wrapped.max_length # you get to stay at 77
9
+ used_custom_terms = []
10
+ remade_batch_tokens = []
11
+ hijack_comments = []
12
+ hijack_fixes = []
13
+ token_count = 0
14
+
15
+ cache = {}
16
+ batch_tokens = self.tokenize(texts)
17
+ batch_multipliers = []
18
+ for tokens in batch_tokens:
19
+ tuple_tokens = tuple(tokens)
20
+
21
+ if tuple_tokens in cache:
22
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
23
+ else:
24
+ fixes = []
25
+ remade_tokens = []
26
+ multipliers = []
27
+ mult = 1.0
28
+
29
+ i = 0
30
+ while i < len(tokens):
31
+ token = tokens[i]
32
+
33
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
34
+
35
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
36
+ if mult_change is not None:
37
+ mult *= mult_change
38
+ i += 1
39
+ elif embedding is None:
40
+ remade_tokens.append(token)
41
+ multipliers.append(mult)
42
+ i += 1
43
+ else:
44
+ emb_len = int(embedding.vec.shape[0])
45
+ fixes.append((len(remade_tokens), embedding))
46
+ remade_tokens += [0] * emb_len
47
+ multipliers += [mult] * emb_len
48
+ used_custom_terms.append((embedding.name, embedding.checksum()))
49
+ i += embedding_length_in_tokens
50
+
51
+ if len(remade_tokens) > maxlen - 2:
52
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
53
+ ovf = remade_tokens[maxlen - 2:]
54
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
55
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
56
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
57
+
58
+ token_count = len(remade_tokens)
59
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
60
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
61
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
62
+
63
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
64
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
65
+
66
+ remade_batch_tokens.append(remade_tokens)
67
+ hijack_fixes.append(fixes)
68
+ batch_multipliers.append(multipliers)
69
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
70
+
71
+
72
+ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
73
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
74
+
75
+ self.hijack.comments += hijack_comments
76
+
77
+ if len(used_custom_terms) > 0:
78
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
79
+
80
+ self.hijack.fixes = hijack_fixes
81
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
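For reference, the emphasis multipliers consulted by process_text_old() come from the token_mults table built in FrozenCLIPEmbedderWithCustomWords.__init__ earlier in this diff; the per-character rule reduces to the small sketch below (bracket_multiplier is an illustrative name, not webui code):

    def bracket_multiplier(text):
        # '(' and ']' strengthen by 1.1 per character, ')' and '[' weaken by 1.1
        mult = 1.0
        for c in text:
            if c == '[' or c == ')':
                mult /= 1.1
            if c == ']' or c == '(':
                mult *= 1.1
        return mult

    print(bracket_multiplier('(('))  # ~1.21, about a 21% boost for doubled parentheses
    print(bracket_multiplier('['))   # ~0.909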
sd/stable-diffusion-webui/modules/sd_hijack_inpainting.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+ import torch
3
+
4
+ from einops import repeat
5
+ from omegaconf import ListConfig
6
+
7
+ import ldm.models.diffusion.ddpm
8
+ import ldm.models.diffusion.ddim
9
+ import ldm.models.diffusion.plms
10
+
11
+ from ldm.models.diffusion.ddpm import LatentDiffusion
12
+ from ldm.models.diffusion.plms import PLMSSampler
13
+ from ldm.models.diffusion.ddim import DDIMSampler, noise_like
14
+ from ldm.models.diffusion.sampling_util import norm_thresholding
15
+
16
+
17
+ @torch.no_grad()
18
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
19
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
20
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
21
+ b, *_, device = *x.shape, x.device
22
+
23
+ def get_model_output(x, t):
24
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
25
+ e_t = self.model.apply_model(x, t, c)
26
+ else:
27
+ x_in = torch.cat([x] * 2)
28
+ t_in = torch.cat([t] * 2)
29
+
30
+ if isinstance(c, dict):
31
+ assert isinstance(unconditional_conditioning, dict)
32
+ c_in = dict()
33
+ for k in c:
34
+ if isinstance(c[k], list):
35
+ c_in[k] = [
36
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
37
+ for i in range(len(c[k]))
38
+ ]
39
+ else:
40
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
41
+ else:
42
+ c_in = torch.cat([unconditional_conditioning, c])
43
+
44
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
45
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
46
+
47
+ if score_corrector is not None:
48
+ assert self.model.parameterization == "eps"
49
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
50
+
51
+ return e_t
52
+
53
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
54
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
55
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
56
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
57
+
58
+ def get_x_prev_and_pred_x0(e_t, index):
59
+ # select parameters corresponding to the currently considered timestep
60
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
61
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
62
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
63
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
64
+
65
+ # current prediction for x_0
66
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
67
+ if quantize_denoised:
68
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
69
+ if dynamic_threshold is not None:
70
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
71
+ # direction pointing to x_t
72
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
73
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
74
+ if noise_dropout > 0.:
75
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
76
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
77
+ return x_prev, pred_x0
78
+
79
+ e_t = get_model_output(x, t)
80
+ if len(old_eps) == 0:
81
+ # Pseudo Improved Euler (2nd order)
82
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
83
+ e_t_next = get_model_output(x_prev, t_next)
84
+ e_t_prime = (e_t + e_t_next) / 2
85
+ elif len(old_eps) == 1:
86
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
87
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
88
+ elif len(old_eps) == 2:
89
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
90
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
91
+ elif len(old_eps) >= 3:
92
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
93
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
94
+
95
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
96
+
97
+ return x_prev, pred_x0, e_t
98
+
99
+
100
+ def do_inpainting_hijack():
101
+ # p_sample_plms is needed because PLMS can't work with dicts as conditionings
102
+
103
+ ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
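The old_eps branches in p_sample_plms above are the standard Adams-Bashforth linear multistep coefficients applied to stored noise predictions. Restated as a compact standalone helper (illustrative name, scalars instead of tensors, and the empty-history case left to the Heun-style step used above):

    def combine_eps(e_t, old_eps):
        # pseudo linear multistep over up to three previous noise predictions
        if len(old_eps) == 1:
            return (3 * e_t - old_eps[-1]) / 2
        if len(old_eps) == 2:
            return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        if len(old_eps) >= 3:
            return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
        raise ValueError("with no history, use the 2nd-order improved Euler step instead")

    print(combine_eps(1.0, [0.9, 0.95]))  # blend of the current and two past eps values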
sd/stable-diffusion-webui/modules/sd_hijack_ip2p.py ADDED
@@ -0,0 +1,13 @@
1
+ import collections
2
+ import os.path
3
+ import sys
4
+ import gc
5
+ import time
6
+
7
+ def should_hijack_ip2p(checkpoint_info):
8
+ from modules import sd_models_config
9
+
10
+ ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
11
+ cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
12
+
13
+ return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
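The same filename rule, restated as a self-contained function that can be tried without webui objects (the function name and example paths are made up for illustration):

    import os.path

    def is_ip2p_mismatch(ckpt_path, cfg_path):
        # hijack when the checkpoint looks like instruct-pix2pix but its config does not
        ckpt = os.path.basename(ckpt_path).lower()
        cfg = os.path.basename(cfg_path).lower()
        return "pix2pix" in ckpt and "pix2pix" not in cfg

    print(is_ip2p_mismatch("models/instruct-pix2pix-22000.ckpt", "configs/v1-inference.yaml"))      # True
    print(is_ip2p_mismatch("models/instruct-pix2pix-22000.ckpt", "configs/instruct-pix2pix.yaml"))  # False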
sd/stable-diffusion-webui/modules/sd_hijack_open_clip.py ADDED
@@ -0,0 +1,37 @@
1
+ import open_clip.tokenizer
2
+ import torch
3
+
4
+ from modules import sd_hijack_clip, devices
5
+ from modules.shared import opts
6
+
7
+ tokenizer = open_clip.tokenizer._tokenizer
8
+
9
+
10
+ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
11
+ def __init__(self, wrapped, hijack):
12
+ super().__init__(wrapped, hijack)
13
+
14
+ self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
15
+ self.id_start = tokenizer.encoder["<start_of_text>"]
16
+ self.id_end = tokenizer.encoder["<end_of_text>"]
17
+ self.id_pad = 0
18
+
19
+ def tokenize(self, texts):
20
+ assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
21
+
22
+ tokenized = [tokenizer.encode(text) for text in texts]
23
+
24
+ return tokenized
25
+
26
+ def encode_with_transformers(self, tokens):
27
+ # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
28
+ z = self.wrapped.encode_with_transformer(tokens)
29
+
30
+ return z
31
+
32
+ def encode_embedding_init_text(self, init_text, nvpt):
33
+ ids = tokenizer.encode(init_text)
34
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
35
+ embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
36
+
37
+ return embedded
sd/stable-diffusion-webui/modules/sd_hijack_optimizations.py ADDED
@@ -0,0 +1,444 @@
1
+ import math
2
+ import sys
3
+ import traceback
4
+ import psutil
5
+
6
+ import torch
7
+ from torch import einsum
8
+
9
+ from ldm.util import default
10
+ from einops import rearrange
11
+
12
+ from modules import shared, errors, devices
13
+ from modules.hypernetworks import hypernetwork
14
+
15
+ from .sub_quadratic_attention import efficient_dot_product_attention
16
+
17
+
18
+ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
19
+ try:
20
+ import xformers.ops
21
+ shared.xformers_available = True
22
+ except Exception:
23
+ print("Cannot import xformers", file=sys.stderr)
24
+ print(traceback.format_exc(), file=sys.stderr)
25
+
26
+
27
+ def get_available_vram():
28
+ if shared.device.type == 'cuda':
29
+ stats = torch.cuda.memory_stats(shared.device)
30
+ mem_active = stats['active_bytes.all.current']
31
+ mem_reserved = stats['reserved_bytes.all.current']
32
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
33
+ mem_free_torch = mem_reserved - mem_active
34
+ mem_free_total = mem_free_cuda + mem_free_torch
35
+ return mem_free_total
36
+ else:
37
+ return psutil.virtual_memory().available
38
+
39
+
40
+ # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
41
+ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
42
+ h = self.heads
43
+
44
+ q_in = self.to_q(x)
45
+ context = default(context, x)
46
+
47
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
48
+ k_in = self.to_k(context_k)
49
+ v_in = self.to_v(context_v)
50
+ del context, context_k, context_v, x
51
+
52
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
53
+ del q_in, k_in, v_in
54
+
55
+ dtype = q.dtype
56
+ if shared.opts.upcast_attn:
57
+ q, k, v = q.float(), k.float(), v.float()
58
+
59
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
60
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
61
+ for i in range(0, q.shape[0], 2):
62
+ end = i + 2
63
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
64
+ s1 *= self.scale
65
+
66
+ s2 = s1.softmax(dim=-1)
67
+ del s1
68
+
69
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
70
+ del s2
71
+ del q, k, v
72
+
73
+ r1 = r1.to(dtype)
74
+
75
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
76
+ del r1
77
+
78
+ return self.to_out(r2)
79
+
80
+
81
+ # taken from https://github.com/Doggettx/stable-diffusion and modified
82
+ def split_cross_attention_forward(self, x, context=None, mask=None):
83
+ h = self.heads
84
+
85
+ q_in = self.to_q(x)
86
+ context = default(context, x)
87
+
88
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
89
+ k_in = self.to_k(context_k)
90
+ v_in = self.to_v(context_v)
91
+
92
+ dtype = q_in.dtype
93
+ if shared.opts.upcast_attn:
94
+ q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
95
+
96
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
97
+ k_in = k_in * self.scale
98
+
99
+ del context, x
100
+
101
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
102
+ del q_in, k_in, v_in
103
+
104
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
105
+
106
+ mem_free_total = get_available_vram()
107
+
108
+ gb = 1024 ** 3
109
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
110
+ modifier = 3 if q.element_size() == 2 else 2.5
111
+ mem_required = tensor_size * modifier
112
+ steps = 1
113
+
114
+ if mem_required > mem_free_total:
115
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
116
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
117
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
118
+
119
+ if steps > 64:
120
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
121
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
122
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
123
+
124
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
125
+ for i in range(0, q.shape[1], slice_size):
126
+ end = i + slice_size
127
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
128
+
129
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
130
+ del s1
131
+
132
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
133
+ del s2
134
+
135
+ del q, k, v
136
+
137
+ r1 = r1.to(dtype)
138
+
139
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
140
+ del r1
141
+
142
+ return self.to_out(r2)
143
+
144
+
145
+ # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
146
+ mem_total_gb = psutil.virtual_memory().total // (1 << 30)
147
+
148
+ def einsum_op_compvis(q, k, v):
149
+ s = einsum('b i d, b j d -> b i j', q, k)
150
+ s = s.softmax(dim=-1, dtype=s.dtype)
151
+ return einsum('b i j, b j d -> b i d', s, v)
152
+
153
+ def einsum_op_slice_0(q, k, v, slice_size):
154
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
155
+ for i in range(0, q.shape[0], slice_size):
156
+ end = i + slice_size
157
+ r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
158
+ return r
159
+
160
+ def einsum_op_slice_1(q, k, v, slice_size):
161
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
162
+ for i in range(0, q.shape[1], slice_size):
163
+ end = i + slice_size
164
+ r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
165
+ return r
166
+
167
+ def einsum_op_mps_v1(q, k, v):
168
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
169
+ return einsum_op_compvis(q, k, v)
170
+ else:
171
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
172
+ if slice_size % 4096 == 0:
173
+ slice_size -= 1
174
+ return einsum_op_slice_1(q, k, v, slice_size)
175
+
176
+ def einsum_op_mps_v2(q, k, v):
177
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
178
+ return einsum_op_compvis(q, k, v)
179
+ else:
180
+ return einsum_op_slice_0(q, k, v, 1)
181
+
182
+ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
183
+ size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
184
+ if size_mb <= max_tensor_mb:
185
+ return einsum_op_compvis(q, k, v)
186
+ div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
187
+ if div <= q.shape[0]:
188
+ return einsum_op_slice_0(q, k, v, q.shape[0] // div)
189
+ return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
190
+
191
+ def einsum_op_cuda(q, k, v):
192
+ stats = torch.cuda.memory_stats(q.device)
193
+ mem_active = stats['active_bytes.all.current']
194
+ mem_reserved = stats['reserved_bytes.all.current']
195
+ mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
196
+ mem_free_torch = mem_reserved - mem_active
197
+ mem_free_total = mem_free_cuda + mem_free_torch
198
+ # Divide by a safety factor since there's copying and fragmentation
199
+ return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
200
+
201
+ def einsum_op(q, k, v):
202
+ if q.device.type == 'cuda':
203
+ return einsum_op_cuda(q, k, v)
204
+
205
+ if q.device.type == 'mps':
206
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
207
+ return einsum_op_mps_v1(q, k, v)
208
+ return einsum_op_mps_v2(q, k, v)
209
+
210
+ # Smaller slices are faster due to L2/L3/SLC caches.
211
+ # Tested on i7 with 8MB L3 cache.
212
+ return einsum_op_tensor_mem(q, k, v, 32)
213
+
214
+ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
215
+ h = self.heads
216
+
217
+ q = self.to_q(x)
218
+ context = default(context, x)
219
+
220
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
221
+ k = self.to_k(context_k)
222
+ v = self.to_v(context_v)
223
+ del context, context_k, context_v, x
224
+
225
+ dtype = q.dtype
226
+ if shared.opts.upcast_attn:
227
+ q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
228
+
229
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
230
+ k = k * self.scale
231
+
232
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
233
+ r = einsum_op(q, k, v)
234
+ r = r.to(dtype)
235
+ return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
236
+
237
+ # -- End of code from https://github.com/invoke-ai/InvokeAI --
238
+
239
+
240
+ # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
241
+ # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
242
+ def sub_quad_attention_forward(self, x, context=None, mask=None):
243
+ assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
244
+
245
+ h = self.heads
246
+
247
+ q = self.to_q(x)
248
+ context = default(context, x)
249
+
250
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
251
+ k = self.to_k(context_k)
252
+ v = self.to_v(context_v)
253
+ del context, context_k, context_v, x
254
+
255
+ q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
256
+ k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
257
+ v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
258
+
259
+ dtype = q.dtype
260
+ if shared.opts.upcast_attn:
261
+ q, k = q.float(), k.float()
262
+
263
+ x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
264
+
265
+ x = x.to(dtype)
266
+
267
+ x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
268
+
269
+ out_proj, dropout = self.to_out
270
+ x = out_proj(x)
271
+ x = dropout(x)
272
+
273
+ return x
274
+
275
+ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
276
+ bytes_per_token = torch.finfo(q.dtype).bits//8
277
+ batch_x_heads, q_tokens, _ = q.shape
278
+ _, k_tokens, _ = k.shape
279
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
280
+
281
+ if chunk_threshold is None:
282
+ chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
283
+ elif chunk_threshold == 0:
284
+ chunk_threshold_bytes = None
285
+ else:
286
+ chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
287
+
288
+ if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
289
+ kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
290
+ elif kv_chunk_size_min == 0:
291
+ kv_chunk_size_min = None
292
+
293
+ if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
294
+ # the big matmul fits into our memory limit; do everything in 1 chunk,
295
+ # i.e. send it down the unchunked fast-path
296
+ query_chunk_size = q_tokens
297
+ kv_chunk_size = k_tokens
298
+
299
+ with devices.without_autocast(disable=q.dtype == v.dtype):
300
+ return efficient_dot_product_attention(
301
+ q,
302
+ k,
303
+ v,
304
+ query_chunk_size=q_chunk_size,
305
+ kv_chunk_size=kv_chunk_size,
306
+ kv_chunk_size_min = kv_chunk_size_min,
307
+ use_checkpoint=use_checkpoint,
308
+ )
309
+
310
+
311
+ def get_xformers_flash_attention_op(q, k, v):
312
+ if not shared.cmd_opts.xformers_flash_attention:
313
+ return None
314
+
315
+ try:
316
+ flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
317
+ fw, bw = flash_attention_op
318
+ if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
319
+ return flash_attention_op
320
+ except Exception as e:
321
+ errors.display_once(e, "enabling flash attention")
322
+
323
+ return None
324
+
325
+
326
+ def xformers_attention_forward(self, x, context=None, mask=None):
327
+ h = self.heads
328
+ q_in = self.to_q(x)
329
+ context = default(context, x)
330
+
331
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
332
+ k_in = self.to_k(context_k)
333
+ v_in = self.to_v(context_v)
334
+
335
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
336
+ del q_in, k_in, v_in
337
+
338
+ dtype = q.dtype
339
+ if shared.opts.upcast_attn:
340
+ q, k = q.float(), k.float()
341
+
342
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
343
+
344
+ out = out.to(dtype)
345
+
346
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
347
+ return self.to_out(out)
348
+
349
+ def cross_attention_attnblock_forward(self, x):
350
+ h_ = x
351
+ h_ = self.norm(h_)
352
+ q1 = self.q(h_)
353
+ k1 = self.k(h_)
354
+ v = self.v(h_)
355
+
356
+ # compute attention
357
+ b, c, h, w = q1.shape
358
+
359
+ q2 = q1.reshape(b, c, h*w)
360
+ del q1
361
+
362
+ q = q2.permute(0, 2, 1) # b,hw,c
363
+ del q2
364
+
365
+ k = k1.reshape(b, c, h*w) # b,c,hw
366
+ del k1
367
+
368
+ h_ = torch.zeros_like(k, device=q.device)
369
+
370
+ mem_free_total = get_available_vram()
371
+
372
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
373
+ mem_required = tensor_size * 2.5
374
+ steps = 1
375
+
376
+ if mem_required > mem_free_total:
377
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
378
+
379
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
380
+ for i in range(0, q.shape[1], slice_size):
381
+ end = i + slice_size
382
+
383
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
384
+ w2 = w1 * (int(c)**(-0.5))
385
+ del w1
386
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
387
+ del w2
388
+
389
+ # attend to values
390
+ v1 = v.reshape(b, c, h*w)
391
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
392
+ del w3
393
+
394
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
395
+ del v1, w4
396
+
397
+ h2 = h_.reshape(b, c, h, w)
398
+ del h_
399
+
400
+ h3 = self.proj_out(h2)
401
+ del h2
402
+
403
+ h3 += x
404
+
405
+ return h3
406
+
407
+ def xformers_attnblock_forward(self, x):
408
+ try:
409
+ h_ = x
410
+ h_ = self.norm(h_)
411
+ q = self.q(h_)
412
+ k = self.k(h_)
413
+ v = self.v(h_)
414
+ b, c, h, w = q.shape
415
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
416
+ dtype = q.dtype
417
+ if shared.opts.upcast_attn:
418
+ q, k = q.float(), k.float()
419
+ q = q.contiguous()
420
+ k = k.contiguous()
421
+ v = v.contiguous()
422
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
423
+ out = out.to(dtype)
424
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
425
+ out = self.proj_out(out)
426
+ return x + out
427
+ except NotImplementedError:
428
+ return cross_attention_attnblock_forward(self, x)
429
+
430
+ def sub_quad_attnblock_forward(self, x):
431
+ h_ = x
432
+ h_ = self.norm(h_)
433
+ q = self.q(h_)
434
+ k = self.k(h_)
435
+ v = self.v(h_)
436
+ b, c, h, w = q.shape
437
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
438
+ q = q.contiguous()
439
+ k = k.contiguous()
440
+ v = v.contiguous()
441
+ out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
442
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
443
+ out = self.proj_out(out)
444
+ return x + out
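The memory-driven slicing used by split_cross_attention_forward and cross_attention_attnblock_forward above boils down to one calculation: estimate the size of the full attention score tensor, then keep doubling the number of slices until one slice fits in free memory. A pure-Python sketch of that arithmetic (the function name is illustrative; free-memory probing is replaced by a parameter):

    import math

    def attention_steps(batch_heads, q_tokens, k_tokens, element_size, mem_free):
        # headroom factor matches the code above: 3x for fp16, 2.5x otherwise
        modifier = 3 if element_size == 2 else 2.5
        mem_required = batch_heads * q_tokens * k_tokens * element_size * modifier
        if mem_required <= mem_free:
            return 1
        # double the slice count until the per-slice score tensor fits
        return 2 ** math.ceil(math.log2(mem_required / mem_free))

    # e.g. 16 heads, 4096x4096 tokens, fp16, 1 GiB free -> 2 slices
    print(attention_steps(16, 4096, 4096, 2, 1024 ** 3))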
sd/stable-diffusion-webui/modules/sd_hijack_unet.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ from packaging import version
3
+
4
+ from modules import devices
5
+ from modules.sd_hijack_utils import CondFunc
6
+
7
+
8
+ class TorchHijackForUnet:
9
+ """
10
+ This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
11
+ this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
12
+ """
13
+
14
+ def __getattr__(self, item):
15
+ if item == 'cat':
16
+ return self.cat
17
+
18
+ if hasattr(torch, item):
19
+ return getattr(torch, item)
20
+
21
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
22
+
23
+ def cat(self, tensors, *args, **kwargs):
24
+ if len(tensors) == 2:
25
+ a, b = tensors
26
+ if a.shape[-2:] != b.shape[-2:]:
27
+ a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
28
+
29
+ tensors = (a, b)
30
+
31
+ return torch.cat(tensors, *args, **kwargs)
32
+
33
+
34
+ th = TorchHijackForUnet()
35
+
36
+
37
+ # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
38
+ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
39
+
40
+ if isinstance(cond, dict):
41
+ for y in cond.keys():
42
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
43
+
44
+ with devices.autocast():
45
+ return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
46
+
47
+
48
+ class GELUHijack(torch.nn.GELU, torch.nn.Module):
49
+ def __init__(self, *args, **kwargs):
50
+ torch.nn.GELU.__init__(self, *args, **kwargs)
51
+ def forward(self, x):
52
+ if devices.unet_needs_upcast:
53
+ return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
54
+ else:
55
+ return torch.nn.GELU.forward(self, x)
56
+
57
+
58
+ ddpm_edit_hijack = None
59
+ def hijack_ddpm_edit():
60
+ global ddpm_edit_hijack
61
+ if not ddpm_edit_hijack:
62
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
63
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
64
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
65
+
66
+
67
+ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
68
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
69
+ CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
70
+ if version.parse(torch.__version__) <= version.parse("1.13.1"):
71
+ CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
72
+ CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
73
+ CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
74
+
75
+ first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
76
+ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
77
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
78
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
79
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
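A quick illustration of the cat replacement above: when two UNet feature maps differ slightly in spatial size (image sides that are multiples of 8 but not of 64), plain torch.cat would raise, while TorchHijackForUnet.cat first resizes the first tensor to match the second. A hedged usage sketch, assuming torch is installed and the class above is in scope:

    import torch

    th = TorchHijackForUnet()  # same role as the module-level instance above

    a = torch.zeros(1, 4, 19, 19)
    b = torch.zeros(1, 4, 20, 20)
    out = th.cat([a, b], dim=1)  # a is nearest-interpolated to 20x20 before concatenation
    print(out.shape)             # torch.Size([1, 8, 20, 20])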
sd/stable-diffusion-webui/modules/sd_hijack_utils.py ADDED
@@ -0,0 +1,28 @@
1
+ import importlib
2
+
3
+ class CondFunc:
4
+ def __new__(cls, orig_func, sub_func, cond_func):
5
+ self = super(CondFunc, cls).__new__(cls)
6
+ if isinstance(orig_func, str):
7
+ func_path = orig_func.split('.')
8
+ for i in range(len(func_path)-1, -1, -1):
9
+ try:
10
+ resolved_obj = importlib.import_module('.'.join(func_path[:i]))
11
+ break
12
+ except ImportError:
13
+ pass
14
+ for attr_name in func_path[i:-1]:
15
+ resolved_obj = getattr(resolved_obj, attr_name)
16
+ orig_func = getattr(resolved_obj, func_path[-1])
17
+ setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
18
+ self.__init__(orig_func, sub_func, cond_func)
19
+ return lambda *args, **kwargs: self(*args, **kwargs)
20
+ def __init__(self, orig_func, sub_func, cond_func):
21
+ self.__orig_func = orig_func
22
+ self.__sub_func = sub_func
23
+ self.__cond_func = cond_func
24
+ def __call__(self, *args, **kwargs):
25
+ if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
26
+ return self.__sub_func(self.__orig_func, *args, **kwargs)
27
+ else:
28
+ return self.__orig_func(*args, **kwargs)
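CondFunc is invoked throughout these hijack modules with dotted-path strings, but it also accepts a callable directly, which makes the orig_func/sub_func/cond_func contract easy to demonstrate in isolation (the example below is illustrative and not part of the webui):

    import math

    # reroute sqrt only for negative inputs; otherwise fall through to the original
    safe_sqrt = CondFunc(
        math.sqrt,
        lambda orig_func, x: float('nan'),  # sub_func always receives the original first
        lambda orig_func, x: x < 0,         # cond_func decides when sub_func runs
    )

    print(safe_sqrt(9))   # 3.0 - condition false, the original math.sqrt is called
    print(safe_sqrt(-1))  # nan - condition true, the substitute runs instead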