toto10 commited on
Commit
449cca0
1 Parent(s): 5002a4e

e2079503703a6d257b7f98add64aa93c94c0f12610c7aca1e434fe98ffb56c3c

Browse files
Files changed (50) hide show
  1. modules/progress.py +129 -0
  2. modules/prompt_parser.py +439 -0
  3. modules/realesrgan_model.py +132 -0
  4. modules/restart.py +23 -0
  5. modules/safe.py +196 -0
  6. modules/script_callbacks.py +453 -0
  7. modules/script_loading.py +31 -0
  8. modules/scripts.py +680 -0
  9. modules/scripts_auto_postprocessing.py +42 -0
  10. modules/scripts_postprocessing.py +152 -0
  11. modules/sd_disable_initialization.py +93 -0
  12. modules/sd_hijack.py +346 -0
  13. modules/sd_hijack_checkpoint.py +46 -0
  14. modules/sd_hijack_clip.py +349 -0
  15. modules/sd_hijack_clip_old.py +82 -0
  16. modules/sd_hijack_inpainting.py +97 -0
  17. modules/sd_hijack_ip2p.py +10 -0
  18. modules/sd_hijack_open_clip.py +71 -0
  19. modules/sd_hijack_optimizations.py +668 -0
  20. modules/sd_hijack_unet.py +85 -0
  21. modules/sd_hijack_utils.py +28 -0
  22. modules/sd_hijack_xlmr.py +32 -0
  23. modules/sd_models.py +643 -0
  24. modules/sd_models_config.py +125 -0
  25. modules/sd_models_xl.py +99 -0
  26. modules/sd_samplers.py +56 -0
  27. modules/sd_samplers_common.py +95 -0
  28. modules/sd_samplers_compvis.py +224 -0
  29. modules/sd_samplers_kdiffusion.py +476 -0
  30. modules/sd_unet.py +92 -0
  31. modules/sd_vae.py +213 -0
  32. modules/sd_vae_approx.py +86 -0
  33. modules/sd_vae_taesd.py +88 -0
  34. modules/shared.py +912 -0
  35. modules/shared_items.py +69 -0
  36. modules/styles.py +139 -0
  37. modules/sub_quadratic_attention.py +215 -0
  38. modules/sysinfo.py +162 -0
  39. modules/textual_inversion/__pycache__/autocrop.cpython-310.pyc +0 -0
  40. modules/textual_inversion/__pycache__/dataset.cpython-310.pyc +0 -0
  41. modules/textual_inversion/__pycache__/image_embedding.cpython-310.pyc +0 -0
  42. modules/textual_inversion/__pycache__/learn_schedule.cpython-310.pyc +0 -0
  43. modules/textual_inversion/__pycache__/logging.cpython-310.pyc +0 -0
  44. modules/textual_inversion/__pycache__/preprocess.cpython-310.pyc +0 -0
  45. modules/textual_inversion/__pycache__/textual_inversion.cpython-310.pyc +0 -0
  46. modules/textual_inversion/__pycache__/ui.cpython-310.pyc +0 -0
  47. modules/textual_inversion/autocrop.py +340 -0
  48. modules/textual_inversion/dataset.py +246 -0
  49. modules/textual_inversion/image_embedding.py +220 -0
  50. modules/textual_inversion/learn_schedule.py +81 -0
modules/progress.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import time
4
+
5
+ import gradio as gr
6
+ from pydantic import BaseModel, Field
7
+
8
+ from modules.shared import opts
9
+
10
+ import modules.shared as shared
11
+
12
+
13
+ current_task = None
14
+ pending_tasks = {}
15
+ finished_tasks = []
16
+ recorded_results = []
17
+ recorded_results_limit = 2
18
+
19
+
20
+ def start_task(id_task):
21
+ global current_task
22
+
23
+ current_task = id_task
24
+ pending_tasks.pop(id_task, None)
25
+
26
+
27
+ def finish_task(id_task):
28
+ global current_task
29
+
30
+ if current_task == id_task:
31
+ current_task = None
32
+
33
+ finished_tasks.append(id_task)
34
+ if len(finished_tasks) > 16:
35
+ finished_tasks.pop(0)
36
+
37
+
38
+ def record_results(id_task, res):
39
+ recorded_results.append((id_task, res))
40
+ if len(recorded_results) > recorded_results_limit:
41
+ recorded_results.pop(0)
42
+
43
+
44
+ def add_task_to_queue(id_job):
45
+ pending_tasks[id_job] = time.time()
46
+
47
+
48
+ class ProgressRequest(BaseModel):
49
+ id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
50
+ id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
51
+
52
+
53
+ class ProgressResponse(BaseModel):
54
+ active: bool = Field(title="Whether the task is being worked on right now")
55
+ queued: bool = Field(title="Whether the task is in queue")
56
+ completed: bool = Field(title="Whether the task has already finished")
57
+ progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
58
+ eta: float = Field(default=None, title="ETA in secs")
59
+ live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
60
+ id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
61
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
62
+
63
+
64
+ def setup_progress_api(app):
65
+ return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
66
+
67
+
68
+ def progressapi(req: ProgressRequest):
69
+ active = req.id_task == current_task
70
+ queued = req.id_task in pending_tasks
71
+ completed = req.id_task in finished_tasks
72
+
73
+ if not active:
74
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
75
+
76
+ progress = 0
77
+
78
+ job_count, job_no = shared.state.job_count, shared.state.job_no
79
+ sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
80
+
81
+ if job_count > 0:
82
+ progress += job_no / job_count
83
+ if sampling_steps > 0 and job_count > 0:
84
+ progress += 1 / job_count * sampling_step / sampling_steps
85
+
86
+ progress = min(progress, 1)
87
+
88
+ elapsed_since_start = time.time() - shared.state.time_start
89
+ predicted_duration = elapsed_since_start / progress if progress > 0 else None
90
+ eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
91
+
92
+ id_live_preview = req.id_live_preview
93
+ shared.state.set_current_image()
94
+ if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
95
+ image = shared.state.current_image
96
+ if image is not None:
97
+ buffered = io.BytesIO()
98
+
99
+ if opts.live_previews_image_format == "png":
100
+ # using optimize for large images takes an enormous amount of time
101
+ if max(*image.size) <= 256:
102
+ save_kwargs = {"optimize": True}
103
+ else:
104
+ save_kwargs = {"optimize": False, "compress_level": 1}
105
+
106
+ else:
107
+ save_kwargs = {}
108
+
109
+ image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
110
+ base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
111
+ live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
112
+ id_live_preview = shared.state.id_live_preview
113
+ else:
114
+ live_preview = None
115
+ else:
116
+ live_preview = None
117
+
118
+ return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
119
+
120
+
121
+ def restore_progress(id_task):
122
+ while id_task == current_task or id_task in pending_tasks:
123
+ time.sleep(0.1)
124
+
125
+ res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
126
+ if res is not None:
127
+ return res
128
+
129
+ return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
modules/prompt_parser.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from collections import namedtuple
5
+ from typing import List
6
+ import lark
7
+
8
+ # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
9
+ # will be represented with prompt_schedule like this (assuming steps=100):
10
+ # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
11
+ # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
12
+ # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
13
+ # [75, 'fantasy landscape with a lake and an oak in background masterful']
14
+ # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
15
+
16
+ schedule_parser = lark.Lark(r"""
17
+ !start: (prompt | /[][():]/+)*
18
+ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
19
+ !emphasized: "(" prompt ")"
20
+ | "(" prompt ":" prompt ")"
21
+ | "[" prompt "]"
22
+ scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
23
+ alternate: "[" prompt ("|" prompt)+ "]"
24
+ WHITESPACE: /\s+/
25
+ plain: /([^\\\[\]():|]|\\.)+/
26
+ %import common.SIGNED_NUMBER -> NUMBER
27
+ """)
28
+
29
+ def get_learned_conditioning_prompt_schedules(prompts, steps):
30
+ """
31
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
32
+ >>> g("test")
33
+ [[10, 'test']]
34
+ >>> g("a [b:3]")
35
+ [[3, 'a '], [10, 'a b']]
36
+ >>> g("a [b: 3]")
37
+ [[3, 'a '], [10, 'a b']]
38
+ >>> g("a [[[b]]:2]")
39
+ [[2, 'a '], [10, 'a [[b]]']]
40
+ >>> g("[(a:2):3]")
41
+ [[3, ''], [10, '(a:2)']]
42
+ >>> g("a [b : c : 1] d")
43
+ [[1, 'a b d'], [10, 'a c d']]
44
+ >>> g("a[b:[c:d:2]:1]e")
45
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
46
+ >>> g("a [unbalanced")
47
+ [[10, 'a [unbalanced']]
48
+ >>> g("a [b:.5] c")
49
+ [[5, 'a c'], [10, 'a b c']]
50
+ >>> g("a [{b|d{:.5] c") # not handling this right now
51
+ [[5, 'a c'], [10, 'a {b|d{ c']]
52
+ >>> g("((a][:b:c [d:3]")
53
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
54
+ >>> g("[a|(b:1.1)]")
55
+ [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
56
+ """
57
+
58
+ def collect_steps(steps, tree):
59
+ res = [steps]
60
+
61
+ class CollectSteps(lark.Visitor):
62
+ def scheduled(self, tree):
63
+ tree.children[-1] = float(tree.children[-1])
64
+ if tree.children[-1] < 1:
65
+ tree.children[-1] *= steps
66
+ tree.children[-1] = min(steps, int(tree.children[-1]))
67
+ res.append(tree.children[-1])
68
+
69
+ def alternate(self, tree):
70
+ res.extend(range(1, steps+1))
71
+
72
+ CollectSteps().visit(tree)
73
+ return sorted(set(res))
74
+
75
+ def at_step(step, tree):
76
+ class AtStep(lark.Transformer):
77
+ def scheduled(self, args):
78
+ before, after, _, when = args
79
+ yield before or () if step <= when else after
80
+ def alternate(self, args):
81
+ yield next(args[(step - 1)%len(args)])
82
+ def start(self, args):
83
+ def flatten(x):
84
+ if type(x) == str:
85
+ yield x
86
+ else:
87
+ for gen in x:
88
+ yield from flatten(gen)
89
+ return ''.join(flatten(args))
90
+ def plain(self, args):
91
+ yield args[0].value
92
+ def __default__(self, data, children, meta):
93
+ for child in children:
94
+ yield child
95
+ return AtStep().transform(tree)
96
+
97
+ def get_schedule(prompt):
98
+ try:
99
+ tree = schedule_parser.parse(prompt)
100
+ except lark.exceptions.LarkError:
101
+ if 0:
102
+ import traceback
103
+ traceback.print_exc()
104
+ return [[steps, prompt]]
105
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
106
+
107
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
108
+ return [promptdict[prompt] for prompt in prompts]
109
+
110
+
111
+ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
112
+
113
+
114
+ class SdConditioning(list):
115
+ """
116
+ A list with prompts for stable diffusion's conditioner model.
117
+ Can also specify width and height of created image - SDXL needs it.
118
+ """
119
+ def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
120
+ super().__init__()
121
+ self.extend(prompts)
122
+
123
+ if copy_from is None:
124
+ copy_from = prompts
125
+
126
+ self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
127
+ self.width = width or getattr(copy_from, 'width', None)
128
+ self.height = height or getattr(copy_from, 'height', None)
129
+
130
+
131
+
132
+ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
133
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
134
+ and the sampling step at which this condition is to be replaced by the next one.
135
+
136
+ Input:
137
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
138
+
139
+ Output:
140
+ [
141
+ [
142
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
143
+ ],
144
+ [
145
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
146
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
147
+ ]
148
+ ]
149
+ """
150
+ res = []
151
+
152
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
153
+ cache = {}
154
+
155
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
156
+
157
+ cached = cache.get(prompt, None)
158
+ if cached is not None:
159
+ res.append(cached)
160
+ continue
161
+
162
+ texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
163
+ conds = model.get_learned_conditioning(texts)
164
+
165
+ cond_schedule = []
166
+ for i, (end_at_step, _) in enumerate(prompt_schedule):
167
+ if isinstance(conds, dict):
168
+ cond = {k: v[i] for k, v in conds.items()}
169
+ else:
170
+ cond = conds[i]
171
+
172
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
173
+
174
+ cache[prompt] = cond_schedule
175
+ res.append(cond_schedule)
176
+
177
+ return res
178
+
179
+
180
+ re_AND = re.compile(r"\bAND\b")
181
+ re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
182
+
183
+
184
+ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
185
+ res_indexes = []
186
+
187
+ prompt_indexes = {}
188
+ prompt_flat_list = SdConditioning(prompts)
189
+ prompt_flat_list.clear()
190
+
191
+ for prompt in prompts:
192
+ subprompts = re_AND.split(prompt)
193
+
194
+ indexes = []
195
+ for subprompt in subprompts:
196
+ match = re_weight.search(subprompt)
197
+
198
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
199
+
200
+ weight = float(weight) if weight is not None else 1.0
201
+
202
+ index = prompt_indexes.get(text, None)
203
+ if index is None:
204
+ index = len(prompt_flat_list)
205
+ prompt_flat_list.append(text)
206
+ prompt_indexes[text] = index
207
+
208
+ indexes.append((index, weight))
209
+
210
+ res_indexes.append(indexes)
211
+
212
+ return res_indexes, prompt_flat_list, prompt_indexes
213
+
214
+
215
+ class ComposableScheduledPromptConditioning:
216
+ def __init__(self, schedules, weight=1.0):
217
+ self.schedules: List[ScheduledPromptConditioning] = schedules
218
+ self.weight: float = weight
219
+
220
+
221
+ class MulticondLearnedConditioning:
222
+ def __init__(self, shape, batch):
223
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
224
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
225
+
226
+
227
+ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
228
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
229
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
230
+
231
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
232
+ """
233
+
234
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
235
+
236
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
237
+
238
+ res = []
239
+ for indexes in res_indexes:
240
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
241
+
242
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
243
+
244
+
245
+ class DictWithShape(dict):
246
+ def __init__(self, x, shape):
247
+ super().__init__()
248
+ self.update(x)
249
+
250
+ @property
251
+ def shape(self):
252
+ return self["crossattn"].shape
253
+
254
+
255
+ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
256
+ param = c[0][0].cond
257
+ is_dict = isinstance(param, dict)
258
+
259
+ if is_dict:
260
+ dict_cond = param
261
+ res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
262
+ res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
263
+ else:
264
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
265
+
266
+ for i, cond_schedule in enumerate(c):
267
+ target_index = 0
268
+ for current, entry in enumerate(cond_schedule):
269
+ if current_step <= entry.end_at_step:
270
+ target_index = current
271
+ break
272
+
273
+ if is_dict:
274
+ for k, param in cond_schedule[target_index].cond.items():
275
+ res[k][i] = param
276
+ else:
277
+ res[i] = cond_schedule[target_index].cond
278
+
279
+ return res
280
+
281
+
282
+ def stack_conds(tensors):
283
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
284
+ # and won't be able to torch.stack them. So this fixes that.
285
+ token_count = max([x.shape[0] for x in tensors])
286
+ for i in range(len(tensors)):
287
+ if tensors[i].shape[0] != token_count:
288
+ last_vector = tensors[i][-1:]
289
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
290
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
291
+
292
+ return torch.stack(tensors)
293
+
294
+
295
+
296
+ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
297
+ param = c.batch[0][0].schedules[0].cond
298
+
299
+ tensors = []
300
+ conds_list = []
301
+
302
+ for composable_prompts in c.batch:
303
+ conds_for_batch = []
304
+
305
+ for composable_prompt in composable_prompts:
306
+ target_index = 0
307
+ for current, entry in enumerate(composable_prompt.schedules):
308
+ if current_step <= entry.end_at_step:
309
+ target_index = current
310
+ break
311
+
312
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
313
+ tensors.append(composable_prompt.schedules[target_index].cond)
314
+
315
+ conds_list.append(conds_for_batch)
316
+
317
+ if isinstance(tensors[0], dict):
318
+ keys = list(tensors[0].keys())
319
+ stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
320
+ stacked = DictWithShape(stacked, stacked['crossattn'].shape)
321
+ else:
322
+ stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
323
+
324
+ return conds_list, stacked
325
+
326
+
327
+ re_attention = re.compile(r"""
328
+ \\\(|
329
+ \\\)|
330
+ \\\[|
331
+ \\]|
332
+ \\\\|
333
+ \\|
334
+ \(|
335
+ \[|
336
+ :([+-]?[.\d]+)\)|
337
+ \)|
338
+ ]|
339
+ [^\\()\[\]:]+|
340
+ :
341
+ """, re.X)
342
+
343
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
344
+
345
+ def parse_prompt_attention(text):
346
+ """
347
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
348
+ Accepted tokens are:
349
+ (abc) - increases attention to abc by a multiplier of 1.1
350
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
351
+ [abc] - decreases attention to abc by a multiplier of 1.1
352
+ \( - literal character '('
353
+ \[ - literal character '['
354
+ \) - literal character ')'
355
+ \] - literal character ']'
356
+ \\ - literal character '\'
357
+ anything else - just text
358
+
359
+ >>> parse_prompt_attention('normal text')
360
+ [['normal text', 1.0]]
361
+ >>> parse_prompt_attention('an (important) word')
362
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
363
+ >>> parse_prompt_attention('(unbalanced')
364
+ [['unbalanced', 1.1]]
365
+ >>> parse_prompt_attention('\(literal\]')
366
+ [['(literal]', 1.0]]
367
+ >>> parse_prompt_attention('(unnecessary)(parens)')
368
+ [['unnecessaryparens', 1.1]]
369
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
370
+ [['a ', 1.0],
371
+ ['house', 1.5730000000000004],
372
+ [' ', 1.1],
373
+ ['on', 1.0],
374
+ [' a ', 1.1],
375
+ ['hill', 0.55],
376
+ [', sun, ', 1.1],
377
+ ['sky', 1.4641000000000006],
378
+ ['.', 1.1]]
379
+ """
380
+
381
+ res = []
382
+ round_brackets = []
383
+ square_brackets = []
384
+
385
+ round_bracket_multiplier = 1.1
386
+ square_bracket_multiplier = 1 / 1.1
387
+
388
+ def multiply_range(start_position, multiplier):
389
+ for p in range(start_position, len(res)):
390
+ res[p][1] *= multiplier
391
+
392
+ for m in re_attention.finditer(text):
393
+ text = m.group(0)
394
+ weight = m.group(1)
395
+
396
+ if text.startswith('\\'):
397
+ res.append([text[1:], 1.0])
398
+ elif text == '(':
399
+ round_brackets.append(len(res))
400
+ elif text == '[':
401
+ square_brackets.append(len(res))
402
+ elif weight is not None and round_brackets:
403
+ multiply_range(round_brackets.pop(), float(weight))
404
+ elif text == ')' and round_brackets:
405
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
406
+ elif text == ']' and square_brackets:
407
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
408
+ else:
409
+ parts = re.split(re_break, text)
410
+ for i, part in enumerate(parts):
411
+ if i > 0:
412
+ res.append(["BREAK", -1])
413
+ res.append([part, 1.0])
414
+
415
+ for pos in round_brackets:
416
+ multiply_range(pos, round_bracket_multiplier)
417
+
418
+ for pos in square_brackets:
419
+ multiply_range(pos, square_bracket_multiplier)
420
+
421
+ if len(res) == 0:
422
+ res = [["", 1.0]]
423
+
424
+ # merge runs of identical weights
425
+ i = 0
426
+ while i + 1 < len(res):
427
+ if res[i][1] == res[i + 1][1]:
428
+ res[i][0] += res[i + 1][0]
429
+ res.pop(i + 1)
430
+ else:
431
+ i += 1
432
+
433
+ return res
434
+
435
+ if __name__ == "__main__":
436
+ import doctest
437
+ doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
438
+ else:
439
+ import torch # doctest faster
modules/realesrgan_model.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from realesrgan import RealESRGANer
6
+
7
+ from modules.upscaler import Upscaler, UpscalerData
8
+ from modules.shared import cmd_opts, opts
9
+ from modules import modelloader, errors
10
+
11
+
12
+ class UpscalerRealESRGAN(Upscaler):
13
+ def __init__(self, path):
14
+ self.name = "RealESRGAN"
15
+ self.user_path = path
16
+ super().__init__()
17
+ try:
18
+ from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
19
+ from realesrgan import RealESRGANer # noqa: F401
20
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
21
+ self.enable = True
22
+ self.scalers = []
23
+ scalers = self.load_models(path)
24
+
25
+ local_model_paths = self.find_models(ext_filter=[".pth"])
26
+ for scaler in scalers:
27
+ if scaler.local_data_path.startswith("http"):
28
+ filename = modelloader.friendly_name(scaler.local_data_path)
29
+ local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
30
+ if local_model_candidates:
31
+ scaler.local_data_path = local_model_candidates[0]
32
+
33
+ if scaler.name in opts.realesrgan_enabled_models:
34
+ self.scalers.append(scaler)
35
+
36
+ except Exception:
37
+ errors.report("Error importing Real-ESRGAN", exc_info=True)
38
+ self.enable = False
39
+ self.scalers = []
40
+
41
+ def do_upscale(self, img, path):
42
+ if not self.enable:
43
+ return img
44
+
45
+ try:
46
+ info = self.load_model(path)
47
+ except Exception:
48
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
49
+ return img
50
+
51
+ upsampler = RealESRGANer(
52
+ scale=info.scale,
53
+ model_path=info.local_data_path,
54
+ model=info.model(),
55
+ half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
56
+ tile=opts.ESRGAN_tile,
57
+ tile_pad=opts.ESRGAN_tile_overlap,
58
+ )
59
+
60
+ upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
61
+
62
+ image = Image.fromarray(upsampled)
63
+ return image
64
+
65
+ def load_model(self, path):
66
+ for scaler in self.scalers:
67
+ if scaler.data_path == path:
68
+ if scaler.local_data_path.startswith("http"):
69
+ scaler.local_data_path = modelloader.load_file_from_url(
70
+ scaler.data_path,
71
+ model_dir=self.model_download_path,
72
+ )
73
+ if not os.path.exists(scaler.local_data_path):
74
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
75
+ return scaler
76
+ raise ValueError(f"Unable to find model info: {path}")
77
+
78
+ def load_models(self, _):
79
+ return get_realesrgan_models(self)
80
+
81
+
82
+ def get_realesrgan_models(scaler):
83
+ try:
84
+ from basicsr.archs.rrdbnet_arch import RRDBNet
85
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
86
+ models = [
87
+ UpscalerData(
88
+ name="R-ESRGAN General 4xV3",
89
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
90
+ scale=4,
91
+ upscaler=scaler,
92
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
93
+ ),
94
+ UpscalerData(
95
+ name="R-ESRGAN General WDN 4xV3",
96
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
97
+ scale=4,
98
+ upscaler=scaler,
99
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
100
+ ),
101
+ UpscalerData(
102
+ name="R-ESRGAN AnimeVideo",
103
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
104
+ scale=4,
105
+ upscaler=scaler,
106
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
107
+ ),
108
+ UpscalerData(
109
+ name="R-ESRGAN 4x+",
110
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
111
+ scale=4,
112
+ upscaler=scaler,
113
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
114
+ ),
115
+ UpscalerData(
116
+ name="R-ESRGAN 4x+ Anime6B",
117
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
118
+ scale=4,
119
+ upscaler=scaler,
120
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
121
+ ),
122
+ UpscalerData(
123
+ name="R-ESRGAN 2x+",
124
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
125
+ scale=2,
126
+ upscaler=scaler,
127
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
128
+ ),
129
+ ]
130
+ return models
131
+ except Exception:
132
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
modules/restart.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from modules.paths_internal import script_path
5
+
6
+
7
+ def is_restartable() -> bool:
8
+ """
9
+ Return True if the webui is restartable (i.e. there is something watching to restart it with)
10
+ """
11
+ return bool(os.environ.get('SD_WEBUI_RESTART'))
12
+
13
+
14
+ def restart_program() -> None:
15
+ """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
16
+
17
+ (Path(script_path) / "tmp" / "restart").touch()
18
+
19
+ stop_program()
20
+
21
+
22
+ def stop_program() -> None:
23
+ os._exit(0)
modules/safe.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this code is adapted from the script contributed by anon from /h/
2
+
3
+ import pickle
4
+ import collections
5
+
6
+ import torch
7
+ import numpy
8
+ import _codecs
9
+ import zipfile
10
+ import re
11
+
12
+
13
+ # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
14
+ from modules import errors
15
+
16
+ TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
17
+
18
+ def encode(*args):
19
+ out = _codecs.encode(*args)
20
+ return out
21
+
22
+
23
+ class RestrictedUnpickler(pickle.Unpickler):
24
+ extra_handler = None
25
+
26
+ def persistent_load(self, saved_id):
27
+ assert saved_id[0] == 'storage'
28
+
29
+ try:
30
+ return TypedStorage(_internal=True)
31
+ except TypeError:
32
+ return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
33
+
34
+ def find_class(self, module, name):
35
+ if self.extra_handler is not None:
36
+ res = self.extra_handler(module, name)
37
+ if res is not None:
38
+ return res
39
+
40
+ if module == 'collections' and name == 'OrderedDict':
41
+ return getattr(collections, name)
42
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
43
+ return getattr(torch._utils, name)
44
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
45
+ return getattr(torch, name)
46
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
47
+ return getattr(torch.nn.modules.container, name)
48
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
49
+ return getattr(numpy.core.multiarray, name)
50
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
51
+ return getattr(numpy, name)
52
+ if module == '_codecs' and name == 'encode':
53
+ return encode
54
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
55
+ import pytorch_lightning.callbacks
56
+ return pytorch_lightning.callbacks.model_checkpoint
57
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
58
+ import pytorch_lightning.callbacks.model_checkpoint
59
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
60
+ if module == "__builtin__" and name == 'set':
61
+ return set
62
+
63
+ # Forbid everything else.
64
+ raise Exception(f"global '{module}/{name}' is forbidden")
65
+
66
+
67
+ # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
68
+ allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
69
+ data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
70
+
71
+ def check_zip_filenames(filename, names):
72
+ for name in names:
73
+ if allowed_zip_names_re.match(name):
74
+ continue
75
+
76
+ raise Exception(f"bad file inside {filename}: {name}")
77
+
78
+
79
+ def check_pt(filename, extra_handler):
80
+ try:
81
+
82
+ # new pytorch format is a zip file
83
+ with zipfile.ZipFile(filename) as z:
84
+ check_zip_filenames(filename, z.namelist())
85
+
86
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
87
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
88
+ if len(data_pkl_filenames) == 0:
89
+ raise Exception(f"data.pkl not found in {filename}")
90
+ if len(data_pkl_filenames) > 1:
91
+ raise Exception(f"Multiple data.pkl found in {filename}")
92
+ with z.open(data_pkl_filenames[0]) as file:
93
+ unpickler = RestrictedUnpickler(file)
94
+ unpickler.extra_handler = extra_handler
95
+ unpickler.load()
96
+
97
+ except zipfile.BadZipfile:
98
+
99
+ # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
100
+ with open(filename, "rb") as file:
101
+ unpickler = RestrictedUnpickler(file)
102
+ unpickler.extra_handler = extra_handler
103
+ for _ in range(5):
104
+ unpickler.load()
105
+
106
+
107
+ def load(filename, *args, **kwargs):
108
+ return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
109
+
110
+
111
+ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
112
+ """
113
+ this function is intended to be used by extensions that want to load models with
114
+ some extra classes in them that the usual unpickler would find suspicious.
115
+
116
+ Use the extra_handler argument to specify a function that takes module and field name as text,
117
+ and returns that field's value:
118
+
119
+ ```python
120
+ def extra(module, name):
121
+ if module == 'collections' and name == 'OrderedDict':
122
+ return collections.OrderedDict
123
+
124
+ return None
125
+
126
+ safe.load_with_extra('model.pt', extra_handler=extra)
127
+ ```
128
+
129
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
130
+ definitely unsafe.
131
+ """
132
+
133
+ from modules import shared
134
+
135
+ try:
136
+ if not shared.cmd_opts.disable_safe_unpickle:
137
+ check_pt(filename, extra_handler)
138
+
139
+ except pickle.UnpicklingError:
140
+ errors.report(
141
+ f"Error verifying pickled file from {filename}\n"
142
+ "-----> !!!! The file is most likely corrupted !!!! <-----\n"
143
+ "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
144
+ exc_info=True,
145
+ )
146
+ return None
147
+ except Exception:
148
+ errors.report(
149
+ f"Error verifying pickled file from {filename}\n"
150
+ f"The file may be malicious, so the program is not going to read it.\n"
151
+ f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
152
+ exc_info=True,
153
+ )
154
+ return None
155
+
156
+ return unsafe_torch_load(filename, *args, **kwargs)
157
+
158
+
159
+ class Extra:
160
+ """
161
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
162
+ (because it's not your code making the torch.load call). The intended use is like this:
163
+
164
+ ```
165
+ import torch
166
+ from modules import safe
167
+
168
+ def handler(module, name):
169
+ if module == 'torch' and name in ['float64', 'float16']:
170
+ return getattr(torch, name)
171
+
172
+ return None
173
+
174
+ with safe.Extra(handler):
175
+ x = torch.load('model.pt')
176
+ ```
177
+ """
178
+
179
+ def __init__(self, handler):
180
+ self.handler = handler
181
+
182
+ def __enter__(self):
183
+ global global_extra_handler
184
+
185
+ assert global_extra_handler is None, 'already inside an Extra() block'
186
+ global_extra_handler = self.handler
187
+
188
+ def __exit__(self, exc_type, exc_val, exc_tb):
189
+ global global_extra_handler
190
+
191
+ global_extra_handler = None
192
+
193
+
194
+ unsafe_torch_load = torch.load
195
+ torch.load = load
196
+ global_extra_handler = None
modules/script_callbacks.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import os
3
+ from collections import namedtuple
4
+ from typing import Optional, Dict, Any
5
+
6
+ from fastapi import FastAPI
7
+ from gradio import Blocks
8
+
9
+ from modules import errors, timer
10
+
11
+
12
+ def report_exception(c, job):
13
+ errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
14
+
15
+
16
+ class ImageSaveParams:
17
+ def __init__(self, image, p, filename, pnginfo):
18
+ self.image = image
19
+ """the PIL image itself"""
20
+
21
+ self.p = p
22
+ """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
23
+
24
+ self.filename = filename
25
+ """name of file that the image would be saved to"""
26
+
27
+ self.pnginfo = pnginfo
28
+ """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
29
+
30
+
31
+ class CFGDenoiserParams:
32
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
33
+ self.x = x
34
+ """Latent image representation in the process of being denoised"""
35
+
36
+ self.image_cond = image_cond
37
+ """Conditioning image"""
38
+
39
+ self.sigma = sigma
40
+ """Current sigma noise step value"""
41
+
42
+ self.sampling_step = sampling_step
43
+ """Current Sampling step number"""
44
+
45
+ self.total_sampling_steps = total_sampling_steps
46
+ """Total number of sampling steps planned"""
47
+
48
+ self.text_cond = text_cond
49
+ """ Encoder hidden states of text conditioning from prompt"""
50
+
51
+ self.text_uncond = text_uncond
52
+ """ Encoder hidden states of text conditioning from negative prompt"""
53
+
54
+
55
+ class CFGDenoisedParams:
56
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
57
+ self.x = x
58
+ """Latent image representation in the process of being denoised"""
59
+
60
+ self.sampling_step = sampling_step
61
+ """Current Sampling step number"""
62
+
63
+ self.total_sampling_steps = total_sampling_steps
64
+ """Total number of sampling steps planned"""
65
+
66
+ self.inner_model = inner_model
67
+ """Inner model reference used for denoising"""
68
+
69
+
70
+ class AfterCFGCallbackParams:
71
+ def __init__(self, x, sampling_step, total_sampling_steps):
72
+ self.x = x
73
+ """Latent image representation in the process of being denoised"""
74
+
75
+ self.sampling_step = sampling_step
76
+ """Current Sampling step number"""
77
+
78
+ self.total_sampling_steps = total_sampling_steps
79
+ """Total number of sampling steps planned"""
80
+
81
+
82
+ class UiTrainTabParams:
83
+ def __init__(self, txt2img_preview_params):
84
+ self.txt2img_preview_params = txt2img_preview_params
85
+
86
+
87
+ class ImageGridLoopParams:
88
+ def __init__(self, imgs, cols, rows):
89
+ self.imgs = imgs
90
+ self.cols = cols
91
+ self.rows = rows
92
+
93
+
94
+ ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
95
+ callback_map = dict(
96
+ callbacks_app_started=[],
97
+ callbacks_model_loaded=[],
98
+ callbacks_ui_tabs=[],
99
+ callbacks_ui_train_tabs=[],
100
+ callbacks_ui_settings=[],
101
+ callbacks_before_image_saved=[],
102
+ callbacks_image_saved=[],
103
+ callbacks_cfg_denoiser=[],
104
+ callbacks_cfg_denoised=[],
105
+ callbacks_cfg_after_cfg=[],
106
+ callbacks_before_component=[],
107
+ callbacks_after_component=[],
108
+ callbacks_image_grid=[],
109
+ callbacks_infotext_pasted=[],
110
+ callbacks_script_unloaded=[],
111
+ callbacks_before_ui=[],
112
+ callbacks_on_reload=[],
113
+ callbacks_list_optimizers=[],
114
+ callbacks_list_unets=[],
115
+ )
116
+
117
+
118
+ def clear_callbacks():
119
+ for callback_list in callback_map.values():
120
+ callback_list.clear()
121
+
122
+
123
+ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
124
+ for c in callback_map['callbacks_app_started']:
125
+ try:
126
+ c.callback(demo, app)
127
+ timer.startup_timer.record(os.path.basename(c.script))
128
+ except Exception:
129
+ report_exception(c, 'app_started_callback')
130
+
131
+
132
+ def app_reload_callback():
133
+ for c in callback_map['callbacks_on_reload']:
134
+ try:
135
+ c.callback()
136
+ except Exception:
137
+ report_exception(c, 'callbacks_on_reload')
138
+
139
+
140
+ def model_loaded_callback(sd_model):
141
+ for c in callback_map['callbacks_model_loaded']:
142
+ try:
143
+ c.callback(sd_model)
144
+ except Exception:
145
+ report_exception(c, 'model_loaded_callback')
146
+
147
+
148
+ def ui_tabs_callback():
149
+ res = []
150
+
151
+ for c in callback_map['callbacks_ui_tabs']:
152
+ try:
153
+ res += c.callback() or []
154
+ except Exception:
155
+ report_exception(c, 'ui_tabs_callback')
156
+
157
+ return res
158
+
159
+
160
+ def ui_train_tabs_callback(params: UiTrainTabParams):
161
+ for c in callback_map['callbacks_ui_train_tabs']:
162
+ try:
163
+ c.callback(params)
164
+ except Exception:
165
+ report_exception(c, 'callbacks_ui_train_tabs')
166
+
167
+
168
+ def ui_settings_callback():
169
+ for c in callback_map['callbacks_ui_settings']:
170
+ try:
171
+ c.callback()
172
+ except Exception:
173
+ report_exception(c, 'ui_settings_callback')
174
+
175
+
176
+ def before_image_saved_callback(params: ImageSaveParams):
177
+ for c in callback_map['callbacks_before_image_saved']:
178
+ try:
179
+ c.callback(params)
180
+ except Exception:
181
+ report_exception(c, 'before_image_saved_callback')
182
+
183
+
184
+ def image_saved_callback(params: ImageSaveParams):
185
+ for c in callback_map['callbacks_image_saved']:
186
+ try:
187
+ c.callback(params)
188
+ except Exception:
189
+ report_exception(c, 'image_saved_callback')
190
+
191
+
192
+ def cfg_denoiser_callback(params: CFGDenoiserParams):
193
+ for c in callback_map['callbacks_cfg_denoiser']:
194
+ try:
195
+ c.callback(params)
196
+ except Exception:
197
+ report_exception(c, 'cfg_denoiser_callback')
198
+
199
+
200
+ def cfg_denoised_callback(params: CFGDenoisedParams):
201
+ for c in callback_map['callbacks_cfg_denoised']:
202
+ try:
203
+ c.callback(params)
204
+ except Exception:
205
+ report_exception(c, 'cfg_denoised_callback')
206
+
207
+
208
+ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
209
+ for c in callback_map['callbacks_cfg_after_cfg']:
210
+ try:
211
+ c.callback(params)
212
+ except Exception:
213
+ report_exception(c, 'cfg_after_cfg_callback')
214
+
215
+
216
+ def before_component_callback(component, **kwargs):
217
+ for c in callback_map['callbacks_before_component']:
218
+ try:
219
+ c.callback(component, **kwargs)
220
+ except Exception:
221
+ report_exception(c, 'before_component_callback')
222
+
223
+
224
+ def after_component_callback(component, **kwargs):
225
+ for c in callback_map['callbacks_after_component']:
226
+ try:
227
+ c.callback(component, **kwargs)
228
+ except Exception:
229
+ report_exception(c, 'after_component_callback')
230
+
231
+
232
+ def image_grid_callback(params: ImageGridLoopParams):
233
+ for c in callback_map['callbacks_image_grid']:
234
+ try:
235
+ c.callback(params)
236
+ except Exception:
237
+ report_exception(c, 'image_grid')
238
+
239
+
240
+ def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
241
+ for c in callback_map['callbacks_infotext_pasted']:
242
+ try:
243
+ c.callback(infotext, params)
244
+ except Exception:
245
+ report_exception(c, 'infotext_pasted')
246
+
247
+
248
+ def script_unloaded_callback():
249
+ for c in reversed(callback_map['callbacks_script_unloaded']):
250
+ try:
251
+ c.callback()
252
+ except Exception:
253
+ report_exception(c, 'script_unloaded')
254
+
255
+
256
+ def before_ui_callback():
257
+ for c in reversed(callback_map['callbacks_before_ui']):
258
+ try:
259
+ c.callback()
260
+ except Exception:
261
+ report_exception(c, 'before_ui')
262
+
263
+
264
+ def list_optimizers_callback():
265
+ res = []
266
+
267
+ for c in callback_map['callbacks_list_optimizers']:
268
+ try:
269
+ c.callback(res)
270
+ except Exception:
271
+ report_exception(c, 'list_optimizers')
272
+
273
+ return res
274
+
275
+
276
+ def list_unets_callback():
277
+ res = []
278
+
279
+ for c in callback_map['callbacks_list_unets']:
280
+ try:
281
+ c.callback(res)
282
+ except Exception:
283
+ report_exception(c, 'list_unets')
284
+
285
+ return res
286
+
287
+
288
+ def add_callback(callbacks, fun):
289
+ stack = [x for x in inspect.stack() if x.filename != __file__]
290
+ filename = stack[0].filename if stack else 'unknown file'
291
+
292
+ callbacks.append(ScriptCallback(filename, fun))
293
+
294
+
295
+ def remove_current_script_callbacks():
296
+ stack = [x for x in inspect.stack() if x.filename != __file__]
297
+ filename = stack[0].filename if stack else 'unknown file'
298
+ if filename == 'unknown file':
299
+ return
300
+ for callback_list in callback_map.values():
301
+ for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
302
+ callback_list.remove(callback_to_remove)
303
+
304
+
305
+ def remove_callbacks_for_function(callback_func):
306
+ for callback_list in callback_map.values():
307
+ for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
308
+ callback_list.remove(callback_to_remove)
309
+
310
+
311
+ def on_app_started(callback):
312
+ """register a function to be called when the webui started, the gradio `Block` component and
313
+ fastapi `FastAPI` object are passed as the arguments"""
314
+ add_callback(callback_map['callbacks_app_started'], callback)
315
+
316
+
317
+ def on_before_reload(callback):
318
+ """register a function to be called just before the server reloads."""
319
+ add_callback(callback_map['callbacks_on_reload'], callback)
320
+
321
+
322
+ def on_model_loaded(callback):
323
+ """register a function to be called when the stable diffusion model is created; the model is
324
+ passed as an argument; this function is also called when the script is reloaded. """
325
+ add_callback(callback_map['callbacks_model_loaded'], callback)
326
+
327
+
328
+ def on_ui_tabs(callback):
329
+ """register a function to be called when the UI is creating new tabs.
330
+ The function must either return a None, which means no new tabs to be added, or a list, where
331
+ each element is a tuple:
332
+ (gradio_component, title, elem_id)
333
+
334
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
335
+ title is tab text displayed to user in the UI
336
+ elem_id is HTML id for the tab
337
+ """
338
+ add_callback(callback_map['callbacks_ui_tabs'], callback)
339
+
340
+
341
+ def on_ui_train_tabs(callback):
342
+ """register a function to be called when the UI is creating new tabs for the train tab.
343
+ Create your new tabs with gr.Tab.
344
+ """
345
+ add_callback(callback_map['callbacks_ui_train_tabs'], callback)
346
+
347
+
348
+ def on_ui_settings(callback):
349
+ """register a function to be called before UI settings are populated; add your settings
350
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
351
+ add_callback(callback_map['callbacks_ui_settings'], callback)
352
+
353
+
354
+ def on_before_image_saved(callback):
355
+ """register a function to be called before an image is saved to a file.
356
+ The callback is called with one argument:
357
+ - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
358
+ """
359
+ add_callback(callback_map['callbacks_before_image_saved'], callback)
360
+
361
+
362
+ def on_image_saved(callback):
363
+ """register a function to be called after an image is saved to a file.
364
+ The callback is called with one argument:
365
+ - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
366
+ """
367
+ add_callback(callback_map['callbacks_image_saved'], callback)
368
+
369
+
370
+ def on_cfg_denoiser(callback):
371
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
372
+ The callback is called with one argument:
373
+ - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
374
+ """
375
+ add_callback(callback_map['callbacks_cfg_denoiser'], callback)
376
+
377
+
378
+ def on_cfg_denoised(callback):
379
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
380
+ The callback is called with one argument:
381
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
382
+ """
383
+ add_callback(callback_map['callbacks_cfg_denoised'], callback)
384
+
385
+
386
+ def on_cfg_after_cfg(callback):
387
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
388
+ The callback is called with one argument:
389
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
390
+ """
391
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
392
+
393
+
394
+ def on_before_component(callback):
395
+ """register a function to be called before a component is created.
396
+ The callback is called with arguments:
397
+ - component - gradio component that is about to be created.
398
+ - **kwargs - args to gradio.components.IOComponent.__init__ function
399
+
400
+ Use elem_id/label fields of kwargs to figure out which component it is.
401
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
402
+ """
403
+ add_callback(callback_map['callbacks_before_component'], callback)
404
+
405
+
406
+ def on_after_component(callback):
407
+ """register a function to be called after a component is created. See on_before_component for more."""
408
+ add_callback(callback_map['callbacks_after_component'], callback)
409
+
410
+
411
+ def on_image_grid(callback):
412
+ """register a function to be called before making an image grid.
413
+ The callback is called with one argument:
414
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
415
+ """
416
+ add_callback(callback_map['callbacks_image_grid'], callback)
417
+
418
+
419
+ def on_infotext_pasted(callback):
420
+ """register a function to be called before applying an infotext.
421
+ The callback is called with two arguments:
422
+ - infotext: str - raw infotext.
423
+ - result: Dict[str, any] - parsed infotext parameters.
424
+ """
425
+ add_callback(callback_map['callbacks_infotext_pasted'], callback)
426
+
427
+
428
+ def on_script_unloaded(callback):
429
+ """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
430
+ the script did should be reverted here"""
431
+
432
+ add_callback(callback_map['callbacks_script_unloaded'], callback)
433
+
434
+
435
+ def on_before_ui(callback):
436
+ """register a function to be called before the UI is created."""
437
+
438
+ add_callback(callback_map['callbacks_before_ui'], callback)
439
+
440
+
441
+ def on_list_optimizers(callback):
442
+ """register a function to be called when UI is making a list of cross attention optimization options.
443
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
444
+ to it."""
445
+
446
+ add_callback(callback_map['callbacks_list_optimizers'], callback)
447
+
448
+
449
+ def on_list_unets(callback):
450
+ """register a function to be called when UI is making a list of alternative options for unet.
451
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
452
+
453
+ add_callback(callback_map['callbacks_list_unets'], callback)
modules/script_loading.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib.util
3
+
4
+ from modules import errors
5
+
6
+
7
+ def load_module(path):
8
+ module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
9
+ module = importlib.util.module_from_spec(module_spec)
10
+ module_spec.loader.exec_module(module)
11
+
12
+ return module
13
+
14
+
15
+ def preload_extensions(extensions_dir, parser, extension_list=None):
16
+ if not os.path.isdir(extensions_dir):
17
+ return
18
+
19
+ extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
20
+ for dirname in sorted(extensions):
21
+ preload_script = os.path.join(extensions_dir, dirname, "preload.py")
22
+ if not os.path.isfile(preload_script):
23
+ continue
24
+
25
+ try:
26
+ module = load_module(preload_script)
27
+ if hasattr(module, 'preload'):
28
+ module.preload(parser)
29
+
30
+ except Exception:
31
+ errors.report(f"Error running preload() for {preload_script}", exc_info=True)
modules/scripts.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ import inspect
5
+ from collections import namedtuple
6
+
7
+ import gradio as gr
8
+
9
+ from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
10
+
11
+ AlwaysVisible = object()
12
+
13
+
14
+ class PostprocessImageArgs:
15
+ def __init__(self, image):
16
+ self.image = image
17
+
18
+
19
+ class PostprocessBatchListArgs:
20
+ def __init__(self, images):
21
+ self.images = images
22
+
23
+
24
+ class Script:
25
+ name = None
26
+ """script's internal name derived from title"""
27
+
28
+ section = None
29
+ """name of UI section that the script's controls will be placed into"""
30
+
31
+ filename = None
32
+ args_from = None
33
+ args_to = None
34
+ alwayson = False
35
+
36
+ is_txt2img = False
37
+ is_img2img = False
38
+
39
+ group = None
40
+ """A gr.Group component that has all script's UI inside it"""
41
+
42
+ infotext_fields = None
43
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
44
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
45
+ """
46
+
47
+ paste_field_names = None
48
+ """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
49
+ various "Send to <X>" buttons when clicked
50
+ """
51
+
52
+ api_info = None
53
+ """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
54
+
55
+ def title(self):
56
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
57
+
58
+ raise NotImplementedError()
59
+
60
+ def ui(self, is_img2img):
61
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
62
+ The return value should be an array of all components that are used in processing.
63
+ Values of those returned components will be passed to run() and process() functions.
64
+ """
65
+
66
+ pass
67
+
68
+ def show(self, is_img2img):
69
+ """
70
+ is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
71
+
72
+ This function should return:
73
+ - False if the script should not be shown in UI at all
74
+ - True if the script should be shown in UI if it's selected in the scripts dropdown
75
+ - script.AlwaysVisible if the script should be shown in UI at all times
76
+ """
77
+
78
+ return True
79
+
80
+ def run(self, p, *args):
81
+ """
82
+ This function is called if the script has been selected in the script dropdown.
83
+ It must do all processing and return the Processed object with results, same as
84
+ one returned by processing.process_images.
85
+
86
+ Usually the processing is done by calling the processing.process_images function.
87
+
88
+ args contains all values returned by components from ui()
89
+ """
90
+
91
+ pass
92
+
93
+ def before_process(self, p, *args):
94
+ """
95
+ This function is called very early before processing begins for AlwaysVisible scripts.
96
+ You can modify the processing object (p) here, inject hooks, etc.
97
+ args contains all values returned by components from ui()
98
+ """
99
+
100
+ pass
101
+
102
+ def process(self, p, *args):
103
+ """
104
+ This function is called before processing begins for AlwaysVisible scripts.
105
+ You can modify the processing object (p) here, inject hooks, etc.
106
+ args contains all values returned by components from ui()
107
+ """
108
+
109
+ pass
110
+
111
+ def before_process_batch(self, p, *args, **kwargs):
112
+ """
113
+ Called before extra networks are parsed from the prompt, so you can add
114
+ new extra network keywords to the prompt with this callback.
115
+
116
+ **kwargs will have those items:
117
+ - batch_number - index of current batch, from 0 to number of batches-1
118
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
119
+ - seeds - list of seeds for current batch
120
+ - subseeds - list of subseeds for current batch
121
+ """
122
+
123
+ pass
124
+
125
+ def after_extra_networks_activate(self, p, *args, **kwargs):
126
+ """
127
+ Called after extra networks activation, before conds calculation
128
+ allow modification of the network after extra networks activation been applied
129
+ won't be call if p.disable_extra_networks
130
+
131
+ **kwargs will have those items:
132
+ - batch_number - index of current batch, from 0 to number of batches-1
133
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
134
+ - seeds - list of seeds for current batch
135
+ - subseeds - list of subseeds for current batch
136
+ - extra_network_data - list of ExtraNetworkParams for current stage
137
+ """
138
+ pass
139
+
140
+ def process_batch(self, p, *args, **kwargs):
141
+ """
142
+ Same as process(), but called for every batch.
143
+
144
+ **kwargs will have those items:
145
+ - batch_number - index of current batch, from 0 to number of batches-1
146
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
147
+ - seeds - list of seeds for current batch
148
+ - subseeds - list of subseeds for current batch
149
+ """
150
+
151
+ pass
152
+
153
+ def postprocess_batch(self, p, *args, **kwargs):
154
+ """
155
+ Same as process_batch(), but called for every batch after it has been generated.
156
+
157
+ **kwargs will have same items as process_batch, and also:
158
+ - batch_number - index of current batch, from 0 to number of batches-1
159
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
160
+ """
161
+
162
+ pass
163
+
164
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
165
+ """
166
+ Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
167
+ This is useful when you want to update the entire batch instead of individual images.
168
+
169
+ You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
170
+ If the number of images is different from the batch size when returning,
171
+ then the script has the responsibility to also update the following attributes in the processing object (p):
172
+ - p.prompts
173
+ - p.negative_prompts
174
+ - p.seeds
175
+ - p.subseeds
176
+
177
+ **kwargs will have same items as process_batch, and also:
178
+ - batch_number - index of current batch, from 0 to number of batches-1
179
+ """
180
+
181
+ pass
182
+
183
+ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
184
+ """
185
+ Called for every image after it has been generated.
186
+ """
187
+
188
+ pass
189
+
190
+ def postprocess(self, p, processed, *args):
191
+ """
192
+ This function is called after processing ends for AlwaysVisible scripts.
193
+ args contains all values returned by components from ui()
194
+ """
195
+
196
+ pass
197
+
198
+ def before_component(self, component, **kwargs):
199
+ """
200
+ Called before a component is created.
201
+ Use elem_id/label fields of kwargs to figure out which component it is.
202
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
203
+ You can return created components in the ui() function to add them to the list of arguments for your processing functions
204
+ """
205
+
206
+ pass
207
+
208
+ def after_component(self, component, **kwargs):
209
+ """
210
+ Called after a component is created. Same as above.
211
+ """
212
+
213
+ pass
214
+
215
+ def describe(self):
216
+ """unused"""
217
+ return ""
218
+
219
+ def elem_id(self, item_id):
220
+ """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
221
+
222
+ need_tabname = self.show(True) == self.show(False)
223
+ tabkind = 'img2img' if self.is_img2img else 'txt2txt'
224
+ tabname = f"{tabkind}_" if need_tabname else ""
225
+ title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
226
+
227
+ return f'script_{tabname}{title}_{item_id}'
228
+
229
+ def before_hr(self, p, *args):
230
+ """
231
+ This function is called before hires fix start.
232
+ """
233
+ pass
234
+
235
+ current_basedir = paths.script_path
236
+
237
+
238
+ def basedir():
239
+ """returns the base directory for the current script. For scripts in the main scripts directory,
240
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
241
+ (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
242
+ """
243
+ return current_basedir
244
+
245
+
246
+ ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
247
+
248
+ scripts_data = []
249
+ postprocessing_scripts_data = []
250
+ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
251
+
252
+
253
+ def list_scripts(scriptdirname, extension):
254
+ scripts_list = []
255
+
256
+ basedir = os.path.join(paths.script_path, scriptdirname)
257
+ if os.path.exists(basedir):
258
+ for filename in sorted(os.listdir(basedir)):
259
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
260
+
261
+ for ext in extensions.active():
262
+ scripts_list += ext.list_files(scriptdirname, extension)
263
+
264
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
265
+
266
+ return scripts_list
267
+
268
+
269
+ def list_files_with_name(filename):
270
+ res = []
271
+
272
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
273
+
274
+ for dirpath in dirs:
275
+ if not os.path.isdir(dirpath):
276
+ continue
277
+
278
+ path = os.path.join(dirpath, filename)
279
+ if os.path.isfile(path):
280
+ res.append(path)
281
+
282
+ return res
283
+
284
+
285
+ def load_scripts():
286
+ global current_basedir
287
+ scripts_data.clear()
288
+ postprocessing_scripts_data.clear()
289
+ script_callbacks.clear_callbacks()
290
+
291
+ scripts_list = list_scripts("scripts", ".py")
292
+
293
+ syspath = sys.path
294
+
295
+ def register_scripts_from_module(module):
296
+ for script_class in module.__dict__.values():
297
+ if not inspect.isclass(script_class):
298
+ continue
299
+
300
+ if issubclass(script_class, Script):
301
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
302
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
303
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
304
+
305
+ def orderby(basedir):
306
+ # 1st webui, 2nd extensions-builtin, 3rd extensions
307
+ priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
308
+ for key in priority:
309
+ if basedir.startswith(key):
310
+ return priority[key]
311
+ return 9999
312
+
313
+ for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
314
+ try:
315
+ if scriptfile.basedir != paths.script_path:
316
+ sys.path = [scriptfile.basedir] + sys.path
317
+ current_basedir = scriptfile.basedir
318
+
319
+ script_module = script_loading.load_module(scriptfile.path)
320
+ register_scripts_from_module(script_module)
321
+
322
+ except Exception:
323
+ errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)
324
+
325
+ finally:
326
+ sys.path = syspath
327
+ current_basedir = paths.script_path
328
+ timer.startup_timer.record(scriptfile.filename)
329
+
330
+ global scripts_txt2img, scripts_img2img, scripts_postproc
331
+
332
+ scripts_txt2img = ScriptRunner()
333
+ scripts_img2img = ScriptRunner()
334
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
335
+
336
+
337
+ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
338
+ try:
339
+ return func(*args, **kwargs)
340
+ except Exception:
341
+ errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
342
+
343
+ return default
344
+
345
+
346
+ class ScriptRunner:
347
+ def __init__(self):
348
+ self.scripts = []
349
+ self.selectable_scripts = []
350
+ self.alwayson_scripts = []
351
+ self.titles = []
352
+ self.infotext_fields = []
353
+ self.paste_field_names = []
354
+ self.inputs = [None]
355
+
356
+ def initialize_scripts(self, is_img2img):
357
+ from modules import scripts_auto_postprocessing
358
+
359
+ self.scripts.clear()
360
+ self.alwayson_scripts.clear()
361
+ self.selectable_scripts.clear()
362
+
363
+ auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
364
+
365
+ for script_data in auto_processing_scripts + scripts_data:
366
+ script = script_data.script_class()
367
+ script.filename = script_data.path
368
+ script.is_txt2img = not is_img2img
369
+ script.is_img2img = is_img2img
370
+
371
+ visibility = script.show(script.is_img2img)
372
+
373
+ if visibility == AlwaysVisible:
374
+ self.scripts.append(script)
375
+ self.alwayson_scripts.append(script)
376
+ script.alwayson = True
377
+
378
+ elif visibility:
379
+ self.scripts.append(script)
380
+ self.selectable_scripts.append(script)
381
+
382
+ def create_script_ui(self, script):
383
+ import modules.api.models as api_models
384
+
385
+ script.args_from = len(self.inputs)
386
+ script.args_to = len(self.inputs)
387
+
388
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
389
+
390
+ if controls is None:
391
+ return
392
+
393
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
394
+ api_args = []
395
+
396
+ for control in controls:
397
+ control.custom_script_source = os.path.basename(script.filename)
398
+
399
+ arg_info = api_models.ScriptArg(label=control.label or "")
400
+
401
+ for field in ("value", "minimum", "maximum", "step", "choices"):
402
+ v = getattr(control, field, None)
403
+ if v is not None:
404
+ setattr(arg_info, field, v)
405
+
406
+ api_args.append(arg_info)
407
+
408
+ script.api_info = api_models.ScriptInfo(
409
+ name=script.name,
410
+ is_img2img=script.is_img2img,
411
+ is_alwayson=script.alwayson,
412
+ args=api_args,
413
+ )
414
+
415
+ if script.infotext_fields is not None:
416
+ self.infotext_fields += script.infotext_fields
417
+
418
+ if script.paste_field_names is not None:
419
+ self.paste_field_names += script.paste_field_names
420
+
421
+ self.inputs += controls
422
+ script.args_to = len(self.inputs)
423
+
424
+ def setup_ui_for_section(self, section, scriptlist=None):
425
+ if scriptlist is None:
426
+ scriptlist = self.alwayson_scripts
427
+
428
+ for script in scriptlist:
429
+ if script.alwayson and script.section != section:
430
+ continue
431
+
432
+ with gr.Group(visible=script.alwayson) as group:
433
+ self.create_script_ui(script)
434
+
435
+ script.group = group
436
+
437
+ def prepare_ui(self):
438
+ self.inputs = [None]
439
+
440
+ def setup_ui(self):
441
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
442
+
443
+ self.setup_ui_for_section(None)
444
+
445
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
446
+ self.inputs[0] = dropdown
447
+
448
+ self.setup_ui_for_section(None, self.selectable_scripts)
449
+
450
+
451
+ def select_script(script_index):
452
+ selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
453
+
454
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
455
+
456
+ def init_field(title):
457
+ """called when an initial value is set from ui-config.json to show script's UI components"""
458
+
459
+ if title == 'None':
460
+ return
461
+
462
+ script_index = self.titles.index(title)
463
+ self.selectable_scripts[script_index].group.visible = True
464
+
465
+ dropdown.init_field = init_field
466
+
467
+ dropdown.change(
468
+ fn=select_script,
469
+ inputs=[dropdown],
470
+ outputs=[script.group for script in self.selectable_scripts]
471
+ )
472
+
473
+ self.script_load_ctr = 0
474
+
475
+ def onload_script_visibility(params):
476
+ title = params.get('Script', None)
477
+ if title:
478
+ title_index = self.titles.index(title)
479
+ visibility = title_index == self.script_load_ctr
480
+ self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
481
+ return gr.update(visible=visibility)
482
+ else:
483
+ return gr.update(visible=False)
484
+
485
+ self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
486
+ self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
487
+
488
+ return self.inputs
489
+
490
+ def run(self, p, *args):
491
+ script_index = args[0]
492
+
493
+ if script_index == 0:
494
+ return None
495
+
496
+ script = self.selectable_scripts[script_index-1]
497
+
498
+ if script is None:
499
+ return None
500
+
501
+ script_args = args[script.args_from:script.args_to]
502
+ processed = script.run(p, *script_args)
503
+
504
+ shared.total_tqdm.clear()
505
+
506
+ return processed
507
+
508
+ def before_process(self, p):
509
+ for script in self.alwayson_scripts:
510
+ try:
511
+ script_args = p.script_args[script.args_from:script.args_to]
512
+ script.before_process(p, *script_args)
513
+ except Exception:
514
+ errors.report(f"Error running before_process: {script.filename}", exc_info=True)
515
+
516
+ def process(self, p):
517
+ for script in self.alwayson_scripts:
518
+ try:
519
+ script_args = p.script_args[script.args_from:script.args_to]
520
+ script.process(p, *script_args)
521
+ except Exception:
522
+ errors.report(f"Error running process: {script.filename}", exc_info=True)
523
+
524
+ def before_process_batch(self, p, **kwargs):
525
+ for script in self.alwayson_scripts:
526
+ try:
527
+ script_args = p.script_args[script.args_from:script.args_to]
528
+ script.before_process_batch(p, *script_args, **kwargs)
529
+ except Exception:
530
+ errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
531
+
532
+ def after_extra_networks_activate(self, p, **kwargs):
533
+ for script in self.alwayson_scripts:
534
+ try:
535
+ script_args = p.script_args[script.args_from:script.args_to]
536
+ script.after_extra_networks_activate(p, *script_args, **kwargs)
537
+ except Exception:
538
+ errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
539
+
540
+ def process_batch(self, p, **kwargs):
541
+ for script in self.alwayson_scripts:
542
+ try:
543
+ script_args = p.script_args[script.args_from:script.args_to]
544
+ script.process_batch(p, *script_args, **kwargs)
545
+ except Exception:
546
+ errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
547
+
548
+ def postprocess(self, p, processed):
549
+ for script in self.alwayson_scripts:
550
+ try:
551
+ script_args = p.script_args[script.args_from:script.args_to]
552
+ script.postprocess(p, processed, *script_args)
553
+ except Exception:
554
+ errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
555
+
556
+ def postprocess_batch(self, p, images, **kwargs):
557
+ for script in self.alwayson_scripts:
558
+ try:
559
+ script_args = p.script_args[script.args_from:script.args_to]
560
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
561
+ except Exception:
562
+ errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
563
+
564
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
565
+ for script in self.alwayson_scripts:
566
+ try:
567
+ script_args = p.script_args[script.args_from:script.args_to]
568
+ script.postprocess_batch_list(p, pp, *script_args, **kwargs)
569
+ except Exception:
570
+ errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
571
+
572
+ def postprocess_image(self, p, pp: PostprocessImageArgs):
573
+ for script in self.alwayson_scripts:
574
+ try:
575
+ script_args = p.script_args[script.args_from:script.args_to]
576
+ script.postprocess_image(p, pp, *script_args)
577
+ except Exception:
578
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
579
+
580
+ def before_component(self, component, **kwargs):
581
+ for script in self.scripts:
582
+ try:
583
+ script.before_component(component, **kwargs)
584
+ except Exception:
585
+ errors.report(f"Error running before_component: {script.filename}", exc_info=True)
586
+
587
+ def after_component(self, component, **kwargs):
588
+ for script in self.scripts:
589
+ try:
590
+ script.after_component(component, **kwargs)
591
+ except Exception:
592
+ errors.report(f"Error running after_component: {script.filename}", exc_info=True)
593
+
594
+ def reload_sources(self, cache):
595
+ for si, script in list(enumerate(self.scripts)):
596
+ args_from = script.args_from
597
+ args_to = script.args_to
598
+ filename = script.filename
599
+
600
+ module = cache.get(filename, None)
601
+ if module is None:
602
+ module = script_loading.load_module(script.filename)
603
+ cache[filename] = module
604
+
605
+ for script_class in module.__dict__.values():
606
+ if type(script_class) == type and issubclass(script_class, Script):
607
+ self.scripts[si] = script_class()
608
+ self.scripts[si].filename = filename
609
+ self.scripts[si].args_from = args_from
610
+ self.scripts[si].args_to = args_to
611
+
612
+
613
+ def before_hr(self, p):
614
+ for script in self.alwayson_scripts:
615
+ try:
616
+ script_args = p.script_args[script.args_from:script.args_to]
617
+ script.before_hr(p, *script_args)
618
+ except Exception:
619
+ errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
620
+
621
+
622
+ scripts_txt2img: ScriptRunner = None
623
+ scripts_img2img: ScriptRunner = None
624
+ scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
625
+ scripts_current: ScriptRunner = None
626
+
627
+
628
+ def reload_script_body_only():
629
+ cache = {}
630
+ scripts_txt2img.reload_sources(cache)
631
+ scripts_img2img.reload_sources(cache)
632
+
633
+
634
+ reload_scripts = load_scripts # compatibility alias
635
+
636
+
637
+ def add_classes_to_gradio_component(comp):
638
+ """
639
+ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
640
+ """
641
+
642
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
643
+
644
+ if getattr(comp, 'multiselect', False):
645
+ comp.elem_classes.append('multiselect')
646
+
647
+
648
+
649
+ def IOComponent_init(self, *args, **kwargs):
650
+ if scripts_current is not None:
651
+ scripts_current.before_component(self, **kwargs)
652
+
653
+ script_callbacks.before_component_callback(self, **kwargs)
654
+
655
+ res = original_IOComponent_init(self, *args, **kwargs)
656
+
657
+ add_classes_to_gradio_component(self)
658
+
659
+ script_callbacks.after_component_callback(self, **kwargs)
660
+
661
+ if scripts_current is not None:
662
+ scripts_current.after_component(self, **kwargs)
663
+
664
+ return res
665
+
666
+
667
+ original_IOComponent_init = gr.components.IOComponent.__init__
668
+ gr.components.IOComponent.__init__ = IOComponent_init
669
+
670
+
671
+ def BlockContext_init(self, *args, **kwargs):
672
+ res = original_BlockContext_init(self, *args, **kwargs)
673
+
674
+ add_classes_to_gradio_component(self)
675
+
676
+ return res
677
+
678
+
679
+ original_BlockContext_init = gr.blocks.BlockContext.__init__
680
+ gr.blocks.BlockContext.__init__ = BlockContext_init
modules/scripts_auto_postprocessing.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules import scripts, scripts_postprocessing, shared
2
+
3
+
4
+ class ScriptPostprocessingForMainUI(scripts.Script):
5
+ def __init__(self, script_postproc):
6
+ self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
7
+ self.postprocessing_controls = None
8
+
9
+ def title(self):
10
+ return self.script.name
11
+
12
+ def show(self, is_img2img):
13
+ return scripts.AlwaysVisible
14
+
15
+ def ui(self, is_img2img):
16
+ self.postprocessing_controls = self.script.ui()
17
+ return self.postprocessing_controls.values()
18
+
19
+ def postprocess_image(self, p, script_pp, *args):
20
+ args_dict = dict(zip(self.postprocessing_controls, args))
21
+
22
+ pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
23
+ pp.info = {}
24
+ self.script.process(pp, **args_dict)
25
+ p.extra_generation_params.update(pp.info)
26
+ script_pp.image = pp.image
27
+
28
+
29
+ def create_auto_preprocessing_script_data():
30
+ from modules import scripts
31
+
32
+ res = []
33
+
34
+ for name in shared.opts.postprocessing_enable_in_main_ui:
35
+ script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
36
+ if script is None:
37
+ continue
38
+
39
+ constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
40
+ res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
41
+
42
+ return res
modules/scripts_postprocessing.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+
4
+ from modules import errors, shared
5
+
6
+
7
+ class PostprocessedImage:
8
+ def __init__(self, image):
9
+ self.image = image
10
+ self.info = {}
11
+
12
+
13
+ class ScriptPostprocessing:
14
+ filename = None
15
+ controls = None
16
+ args_from = None
17
+ args_to = None
18
+
19
+ order = 1000
20
+ """scripts will be ordred by this value in postprocessing UI"""
21
+
22
+ name = None
23
+ """this function should return the title of the script."""
24
+
25
+ group = None
26
+ """A gr.Group component that has all script's UI inside it"""
27
+
28
+ def ui(self):
29
+ """
30
+ This function should create gradio UI elements. See https://gradio.app/docs/#components
31
+ The return value should be a dictionary that maps parameter names to components used in processing.
32
+ Values of those components will be passed to process() function.
33
+ """
34
+
35
+ pass
36
+
37
+ def process(self, pp: PostprocessedImage, **args):
38
+ """
39
+ This function is called to postprocess the image.
40
+ args contains a dictionary with all values returned by components from ui()
41
+ """
42
+
43
+ pass
44
+
45
+ def image_changed(self):
46
+ pass
47
+
48
+
49
+
50
+
51
+ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
52
+ try:
53
+ res = func(*args, **kwargs)
54
+ return res
55
+ except Exception as e:
56
+ errors.display(e, f"calling {filename}/{funcname}")
57
+
58
+ return default
59
+
60
+
61
+ class ScriptPostprocessingRunner:
62
+ def __init__(self):
63
+ self.scripts = None
64
+ self.ui_created = False
65
+
66
+ def initialize_scripts(self, scripts_data):
67
+ self.scripts = []
68
+
69
+ for script_data in scripts_data:
70
+ script: ScriptPostprocessing = script_data.script_class()
71
+ script.filename = script_data.path
72
+
73
+ if script.name == "Simple Upscale":
74
+ continue
75
+
76
+ self.scripts.append(script)
77
+
78
+ def create_script_ui(self, script, inputs):
79
+ script.args_from = len(inputs)
80
+ script.args_to = len(inputs)
81
+
82
+ script.controls = wrap_call(script.ui, script.filename, "ui")
83
+
84
+ for control in script.controls.values():
85
+ control.custom_script_source = os.path.basename(script.filename)
86
+
87
+ inputs += list(script.controls.values())
88
+ script.args_to = len(inputs)
89
+
90
+ def scripts_in_preferred_order(self):
91
+ if self.scripts is None:
92
+ import modules.scripts
93
+ self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
94
+
95
+ scripts_order = shared.opts.postprocessing_operation_order
96
+
97
+ def script_score(name):
98
+ for i, possible_match in enumerate(scripts_order):
99
+ if possible_match == name:
100
+ return i
101
+
102
+ return len(self.scripts)
103
+
104
+ script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
105
+
106
+ return sorted(self.scripts, key=lambda x: script_scores[x.name])
107
+
108
+ def setup_ui(self):
109
+ inputs = []
110
+
111
+ for script in self.scripts_in_preferred_order():
112
+ with gr.Row() as group:
113
+ self.create_script_ui(script, inputs)
114
+
115
+ script.group = group
116
+
117
+ self.ui_created = True
118
+ return inputs
119
+
120
+ def run(self, pp: PostprocessedImage, args):
121
+ for script in self.scripts_in_preferred_order():
122
+ shared.state.job = script.name
123
+
124
+ script_args = args[script.args_from:script.args_to]
125
+
126
+ process_args = {}
127
+ for (name, _component), value in zip(script.controls.items(), script_args):
128
+ process_args[name] = value
129
+
130
+ script.process(pp, **process_args)
131
+
132
+ def create_args_for_run(self, scripts_args):
133
+ if not self.ui_created:
134
+ with gr.Blocks(analytics_enabled=False):
135
+ self.setup_ui()
136
+
137
+ scripts = self.scripts_in_preferred_order()
138
+ args = [None] * max([x.args_to for x in scripts])
139
+
140
+ for script in scripts:
141
+ script_args_dict = scripts_args.get(script.name, None)
142
+ if script_args_dict is not None:
143
+
144
+ for i, name in enumerate(script.controls):
145
+ args[script.args_from + i] = script_args_dict.get(name, None)
146
+
147
+ return args
148
+
149
+ def image_changed(self):
150
+ for script in self.scripts_in_preferred_order():
151
+ script.image_changed()
152
+
modules/sd_disable_initialization.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ldm.modules.encoders.modules
2
+ import open_clip
3
+ import torch
4
+ import transformers.utils.hub
5
+
6
+
7
+ class DisableInitialization:
8
+ """
9
+ When an object of this class enters a `with` block, it starts:
10
+ - preventing torch's layer initialization functions from working
11
+ - changes CLIP and OpenCLIP to not download model weights
12
+ - changes CLIP to not make requests to check if there is a new version of a file you already have
13
+
14
+ When it leaves the block, it reverts everything to how it was before.
15
+
16
+ Use it like this:
17
+ ```
18
+ with DisableInitialization():
19
+ do_things()
20
+ ```
21
+ """
22
+
23
+ def __init__(self, disable_clip=True):
24
+ self.replaced = []
25
+ self.disable_clip = disable_clip
26
+
27
+ def replace(self, obj, field, func):
28
+ original = getattr(obj, field, None)
29
+ if original is None:
30
+ return None
31
+
32
+ self.replaced.append((obj, field, original))
33
+ setattr(obj, field, func)
34
+
35
+ return original
36
+
37
+ def __enter__(self):
38
+ def do_nothing(*args, **kwargs):
39
+ pass
40
+
41
+ def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
42
+ return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
43
+
44
+ def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
45
+ res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
46
+ res.name_or_path = pretrained_model_name_or_path
47
+ return res
48
+
49
+ def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
50
+ args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
51
+ return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
52
+
53
+ def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
54
+
55
+ # this file is always 404, prevent making request
56
+ if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
57
+ return None
58
+
59
+ try:
60
+ res = original(url, *args, local_files_only=True, **kwargs)
61
+ if res is None:
62
+ res = original(url, *args, local_files_only=False, **kwargs)
63
+ return res
64
+ except Exception:
65
+ return original(url, *args, local_files_only=False, **kwargs)
66
+
67
+ def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
68
+ return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
69
+
70
+ def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
71
+ return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
72
+
73
+ def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
74
+ return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
75
+
76
+ self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
77
+ self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
78
+ self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
79
+
80
+ if self.disable_clip:
81
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
82
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
83
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
84
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
85
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
86
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
87
+
88
+ def __exit__(self, exc_type, exc_val, exc_tb):
89
+ for obj, field, original in self.replaced:
90
+ setattr(obj, field, original)
91
+
92
+ self.replaced.clear()
93
+
modules/sd_hijack.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn.functional import silu
3
+ from types import MethodType
4
+
5
+ import modules.textual_inversion.textual_inversion
6
+ from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
7
+ from modules.hypernetworks import hypernetwork
8
+ from modules.shared import cmd_opts
9
+ from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
10
+
11
+ import ldm.modules.attention
12
+ import ldm.modules.diffusionmodules.model
13
+ import ldm.modules.diffusionmodules.openaimodel
14
+ import ldm.models.diffusion.ddim
15
+ import ldm.models.diffusion.plms
16
+ import ldm.modules.encoders.modules
17
+
18
+ import sgm.modules.attention
19
+ import sgm.modules.diffusionmodules.model
20
+ import sgm.modules.diffusionmodules.openaimodel
21
+ import sgm.modules.encoders.modules
22
+
23
+ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
24
+ diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
25
+ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
26
+
27
+ # new memory efficient cross attention blocks do not support hypernets and we already
28
+ # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
29
+ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
30
+ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
31
+
32
+ # silence new console spam from SD2
33
+ ldm.modules.attention.print = lambda *args: None
34
+ ldm.modules.diffusionmodules.model.print = lambda *args: None
35
+
36
+ optimizers = []
37
+ current_optimizer: sd_hijack_optimizations.SdOptimization = None
38
+
39
+
40
+ def list_optimizers():
41
+ new_optimizers = script_callbacks.list_optimizers_callback()
42
+
43
+ new_optimizers = [x for x in new_optimizers if x.is_available()]
44
+
45
+ new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
46
+
47
+ optimizers.clear()
48
+ optimizers.extend(new_optimizers)
49
+
50
+
51
+ def apply_optimizations(option=None):
52
+ global current_optimizer
53
+
54
+ undo_optimizations()
55
+
56
+ if len(optimizers) == 0:
57
+ # a script can access the model very early, and optimizations would not be filled by then
58
+ current_optimizer = None
59
+ return ''
60
+
61
+ ldm.modules.diffusionmodules.model.nonlinearity = silu
62
+ ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
63
+
64
+ sgm.modules.diffusionmodules.model.nonlinearity = silu
65
+ sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
66
+
67
+ if current_optimizer is not None:
68
+ current_optimizer.undo()
69
+ current_optimizer = None
70
+
71
+ selection = option or shared.opts.cross_attention_optimization
72
+ if selection == "Automatic" and len(optimizers) > 0:
73
+ matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
74
+ else:
75
+ matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
76
+
77
+ if selection == "None":
78
+ matching_optimizer = None
79
+ elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
80
+ matching_optimizer = None
81
+ elif matching_optimizer is None:
82
+ matching_optimizer = optimizers[0]
83
+
84
+ if matching_optimizer is not None:
85
+ print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
86
+ matching_optimizer.apply()
87
+ print("done.")
88
+ current_optimizer = matching_optimizer
89
+ return current_optimizer.name
90
+ else:
91
+ print("Disabling attention optimization")
92
+ return ''
93
+
94
+
95
+ def undo_optimizations():
96
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
97
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
98
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
99
+
100
+ sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
101
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
102
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
103
+
104
+
105
+ def fix_checkpoint():
106
+ """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
107
+ checkpoints to be added when not training (there's a warning)"""
108
+
109
+ pass
110
+
111
+
112
+ def weighted_loss(sd_model, pred, target, mean=True):
113
+ #Calculate the weight normally, but ignore the mean
114
+ loss = sd_model._old_get_loss(pred, target, mean=False)
115
+
116
+ #Check if we have weights available
117
+ weight = getattr(sd_model, '_custom_loss_weight', None)
118
+ if weight is not None:
119
+ loss *= weight
120
+
121
+ #Return the loss, as mean if specified
122
+ return loss.mean() if mean else loss
123
+
124
+ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
125
+ try:
126
+ #Temporarily append weights to a place accessible during loss calc
127
+ sd_model._custom_loss_weight = w
128
+
129
+ #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
130
+ #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
131
+ if not hasattr(sd_model, '_old_get_loss'):
132
+ sd_model._old_get_loss = sd_model.get_loss
133
+ sd_model.get_loss = MethodType(weighted_loss, sd_model)
134
+
135
+ #Run the standard forward function, but with the patched 'get_loss'
136
+ return sd_model.forward(x, c, *args, **kwargs)
137
+ finally:
138
+ try:
139
+ #Delete temporary weights if appended
140
+ del sd_model._custom_loss_weight
141
+ except AttributeError:
142
+ pass
143
+
144
+ #If we have an old loss function, reset the loss function to the original one
145
+ if hasattr(sd_model, '_old_get_loss'):
146
+ sd_model.get_loss = sd_model._old_get_loss
147
+ del sd_model._old_get_loss
148
+
149
+ def apply_weighted_forward(sd_model):
150
+ #Add new function 'weighted_forward' that can be called to calc weighted loss
151
+ sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
152
+
153
+ def undo_weighted_forward(sd_model):
154
+ try:
155
+ del sd_model.weighted_forward
156
+ except AttributeError:
157
+ pass
158
+
159
+
160
+ class StableDiffusionModelHijack:
161
+ fixes = None
162
+ layers = None
163
+ circular_enabled = False
164
+ clip = None
165
+ optimization_method = None
166
+
167
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
168
+
169
+ def __init__(self):
170
+ self.extra_generation_params = {}
171
+ self.comments = []
172
+
173
+ self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
174
+
175
+ def apply_optimizations(self, option=None):
176
+ try:
177
+ self.optimization_method = apply_optimizations(option)
178
+ except Exception as e:
179
+ errors.display(e, "applying cross attention optimization")
180
+ undo_optimizations()
181
+
182
+ def hijack(self, m):
183
+ conditioner = getattr(m, 'conditioner', None)
184
+ if conditioner:
185
+ text_cond_models = []
186
+
187
+ for i in range(len(conditioner.embedders)):
188
+ embedder = conditioner.embedders[i]
189
+ typename = type(embedder).__name__
190
+ if typename == 'FrozenOpenCLIPEmbedder':
191
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
192
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
193
+ text_cond_models.append(conditioner.embedders[i])
194
+ if typename == 'FrozenCLIPEmbedder':
195
+ model_embeddings = embedder.transformer.text_model.embeddings
196
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
197
+ conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
198
+ text_cond_models.append(conditioner.embedders[i])
199
+ if typename == 'FrozenOpenCLIPEmbedder2':
200
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
201
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
202
+ text_cond_models.append(conditioner.embedders[i])
203
+
204
+ if len(text_cond_models) == 1:
205
+ m.cond_stage_model = text_cond_models[0]
206
+ else:
207
+ m.cond_stage_model = conditioner
208
+
209
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
210
+ model_embeddings = m.cond_stage_model.roberta.embeddings
211
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
212
+ m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
213
+
214
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
215
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
216
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
217
+ m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
218
+
219
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
220
+ m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
221
+ m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
222
+
223
+ apply_weighted_forward(m)
224
+ if m.cond_stage_key == "edit":
225
+ sd_hijack_unet.hijack_ddpm_edit()
226
+
227
+ self.apply_optimizations()
228
+
229
+ self.clip = m.cond_stage_model
230
+
231
+ def flatten(el):
232
+ flattened = [flatten(children) for children in el.children()]
233
+ res = [el]
234
+ for c in flattened:
235
+ res += c
236
+ return res
237
+
238
+ self.layers = flatten(m)
239
+
240
+ if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
241
+ ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
242
+
243
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
244
+
245
+ def undo_hijack(self, m):
246
+ if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
247
+ m.cond_stage_model = m.cond_stage_model.wrapped
248
+
249
+ elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
250
+ m.cond_stage_model = m.cond_stage_model.wrapped
251
+
252
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
253
+ if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
254
+ model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
255
+ elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
256
+ m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
257
+ m.cond_stage_model = m.cond_stage_model.wrapped
258
+
259
+ undo_optimizations()
260
+ undo_weighted_forward(m)
261
+
262
+ self.apply_circular(False)
263
+ self.layers = None
264
+ self.clip = None
265
+
266
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
267
+
268
+ def apply_circular(self, enable):
269
+ if self.circular_enabled == enable:
270
+ return
271
+
272
+ self.circular_enabled = enable
273
+
274
+ for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
275
+ layer.padding_mode = 'circular' if enable else 'zeros'
276
+
277
+ def clear_comments(self):
278
+ self.comments = []
279
+ self.extra_generation_params = {}
280
+
281
+ def get_prompt_lengths(self, text):
282
+ if self.clip is None:
283
+ return "-", "-"
284
+
285
+ _, token_count = self.clip.process_texts([text])
286
+
287
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
288
+
289
+ def redo_hijack(self, m):
290
+ self.undo_hijack(m)
291
+ self.hijack(m)
292
+
293
+
294
+ class EmbeddingsWithFixes(torch.nn.Module):
295
+ def __init__(self, wrapped, embeddings):
296
+ super().__init__()
297
+ self.wrapped = wrapped
298
+ self.embeddings = embeddings
299
+
300
+ def forward(self, input_ids):
301
+ batch_fixes = self.embeddings.fixes
302
+ self.embeddings.fixes = None
303
+
304
+ inputs_embeds = self.wrapped(input_ids)
305
+
306
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
307
+ return inputs_embeds
308
+
309
+ vecs = []
310
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
311
+ for offset, embedding in fixes:
312
+ emb = devices.cond_cast_unet(embedding.vec)
313
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
314
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
315
+
316
+ vecs.append(tensor)
317
+
318
+ return torch.stack(vecs)
319
+
320
+
321
+ def add_circular_option_to_conv_2d():
322
+ conv2d_constructor = torch.nn.Conv2d.__init__
323
+
324
+ def conv2d_constructor_circular(self, *args, **kwargs):
325
+ return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
326
+
327
+ torch.nn.Conv2d.__init__ = conv2d_constructor_circular
328
+
329
+
330
+ model_hijack = StableDiffusionModelHijack()
331
+
332
+
333
+ def register_buffer(self, name, attr):
334
+ """
335
+ Fix register buffer bug for Mac OS.
336
+ """
337
+
338
+ if type(attr) == torch.Tensor:
339
+ if attr.device != devices.device:
340
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
341
+
342
+ setattr(self, name, attr)
343
+
344
+
345
+ ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
346
+ ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
modules/sd_hijack_checkpoint.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.checkpoint import checkpoint
2
+
3
+ import ldm.modules.attention
4
+ import ldm.modules.diffusionmodules.openaimodel
5
+
6
+
7
+ def BasicTransformerBlock_forward(self, x, context=None):
8
+ return checkpoint(self._forward, x, context)
9
+
10
+
11
+ def AttentionBlock_forward(self, x):
12
+ return checkpoint(self._forward, x)
13
+
14
+
15
+ def ResBlock_forward(self, x, emb):
16
+ return checkpoint(self._forward, x, emb)
17
+
18
+
19
+ stored = []
20
+
21
+
22
+ def add():
23
+ if len(stored) != 0:
24
+ return
25
+
26
+ stored.extend([
27
+ ldm.modules.attention.BasicTransformerBlock.forward,
28
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
29
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
30
+ ])
31
+
32
+ ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
33
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
34
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
35
+
36
+
37
+ def remove():
38
+ if len(stored) == 0:
39
+ return
40
+
41
+ ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
42
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
43
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
44
+
45
+ stored.clear()
46
+
modules/sd_hijack_clip.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import namedtuple
3
+
4
+ import torch
5
+
6
+ from modules import prompt_parser, devices, sd_hijack
7
+ from modules.shared import opts
8
+
9
+
10
+ class PromptChunk:
11
+ """
12
+ This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
13
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
14
+ Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
15
+ so just 75 tokens from prompt.
16
+ """
17
+
18
+ def __init__(self):
19
+ self.tokens = []
20
+ self.multipliers = []
21
+ self.fixes = []
22
+
23
+
24
+ PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
25
+ """An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
26
+ chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
27
+ are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
28
+
29
+
30
+ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
31
+ """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
32
+ have unlimited prompt length and assign weights to tokens in prompt.
33
+ """
34
+
35
+ def __init__(self, wrapped, hijack):
36
+ super().__init__()
37
+
38
+ self.wrapped = wrapped
39
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
40
+ depending on model."""
41
+
42
+ self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
43
+ self.chunk_length = 75
44
+
45
+ self.is_trainable = getattr(wrapped, 'is_trainable', False)
46
+ self.input_key = getattr(wrapped, 'input_key', 'txt')
47
+ self.legacy_ucg_val = None
48
+
49
+ def empty_chunk(self):
50
+ """creates an empty PromptChunk and returns it"""
51
+
52
+ chunk = PromptChunk()
53
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
54
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
55
+ return chunk
56
+
57
+ def get_target_prompt_token_count(self, token_count):
58
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
59
+
60
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
61
+
62
+ def tokenize(self, texts):
63
+ """Converts a batch of texts into a batch of token ids"""
64
+
65
+ raise NotImplementedError
66
+
67
+ def encode_with_transformers(self, tokens):
68
+ """
69
+ converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
70
+ All python lists with tokens are assumed to have same length, usually 77.
71
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
72
+ model - can be 768 and 1024.
73
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
74
+ """
75
+
76
+ raise NotImplementedError
77
+
78
+ def encode_embedding_init_text(self, init_text, nvpt):
79
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
80
+ transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
81
+
82
+ raise NotImplementedError
83
+
84
+ def tokenize_line(self, line):
85
+ """
86
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
87
+ represent the prompt.
88
+ Returns the list and the total number of tokens in the prompt.
89
+ """
90
+
91
+ if opts.enable_emphasis:
92
+ parsed = prompt_parser.parse_prompt_attention(line)
93
+ else:
94
+ parsed = [[line, 1.0]]
95
+
96
+ tokenized = self.tokenize([text for text, _ in parsed])
97
+
98
+ chunks = []
99
+ chunk = PromptChunk()
100
+ token_count = 0
101
+ last_comma = -1
102
+
103
+ def next_chunk(is_last=False):
104
+ """puts current chunk into the list of results and produces the next one - empty;
105
+ if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count"""
106
+ nonlocal token_count
107
+ nonlocal last_comma
108
+ nonlocal chunk
109
+
110
+ if is_last:
111
+ token_count += len(chunk.tokens)
112
+ else:
113
+ token_count += self.chunk_length
114
+
115
+ to_add = self.chunk_length - len(chunk.tokens)
116
+ if to_add > 0:
117
+ chunk.tokens += [self.id_end] * to_add
118
+ chunk.multipliers += [1.0] * to_add
119
+
120
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
121
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
122
+
123
+ last_comma = -1
124
+ chunks.append(chunk)
125
+ chunk = PromptChunk()
126
+
127
+ for tokens, (text, weight) in zip(tokenized, parsed):
128
+ if text == 'BREAK' and weight == -1:
129
+ next_chunk()
130
+ continue
131
+
132
+ position = 0
133
+ while position < len(tokens):
134
+ token = tokens[position]
135
+
136
+ if token == self.comma_token:
137
+ last_comma = len(chunk.tokens)
138
+
139
+ # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
140
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
141
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
142
+ break_location = last_comma + 1
143
+
144
+ reloc_tokens = chunk.tokens[break_location:]
145
+ reloc_mults = chunk.multipliers[break_location:]
146
+
147
+ chunk.tokens = chunk.tokens[:break_location]
148
+ chunk.multipliers = chunk.multipliers[:break_location]
149
+
150
+ next_chunk()
151
+ chunk.tokens = reloc_tokens
152
+ chunk.multipliers = reloc_mults
153
+
154
+ if len(chunk.tokens) == self.chunk_length:
155
+ next_chunk()
156
+
157
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
158
+ if embedding is None:
159
+ chunk.tokens.append(token)
160
+ chunk.multipliers.append(weight)
161
+ position += 1
162
+ continue
163
+
164
+ emb_len = int(embedding.vec.shape[0])
165
+ if len(chunk.tokens) + emb_len > self.chunk_length:
166
+ next_chunk()
167
+
168
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
169
+
170
+ chunk.tokens += [0] * emb_len
171
+ chunk.multipliers += [weight] * emb_len
172
+ position += embedding_length_in_tokens
173
+
174
+ if chunk.tokens or not chunks:
175
+ next_chunk(is_last=True)
176
+
177
+ return chunks, token_count
178
+
179
+ def process_texts(self, texts):
180
+ """
181
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
182
+ length, in tokens, of all texts.
183
+ """
184
+
185
+ token_count = 0
186
+
187
+ cache = {}
188
+ batch_chunks = []
189
+ for line in texts:
190
+ if line in cache:
191
+ chunks = cache[line]
192
+ else:
193
+ chunks, current_token_count = self.tokenize_line(line)
194
+ token_count = max(current_token_count, token_count)
195
+
196
+ cache[line] = chunks
197
+
198
+ batch_chunks.append(chunks)
199
+
200
+ return batch_chunks, token_count
201
+
202
+ def forward(self, texts):
203
+ """
204
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
205
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
206
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
207
+ An example shape returned by this function can be: (2, 77, 768).
208
+ For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
209
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
210
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
211
+ """
212
+
213
+ if opts.use_old_emphasis_implementation:
214
+ import modules.sd_hijack_clip_old
215
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
216
+
217
+ batch_chunks, token_count = self.process_texts(texts)
218
+
219
+ used_embeddings = {}
220
+ chunk_count = max([len(x) for x in batch_chunks])
221
+
222
+ zs = []
223
+ for i in range(chunk_count):
224
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
225
+
226
+ tokens = [x.tokens for x in batch_chunk]
227
+ multipliers = [x.multipliers for x in batch_chunk]
228
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
229
+
230
+ for fixes in self.hijack.fixes:
231
+ for _position, embedding in fixes:
232
+ used_embeddings[embedding.name] = embedding
233
+
234
+ z = self.process_tokens(tokens, multipliers)
235
+ zs.append(z)
236
+
237
+ if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
238
+ hashes = []
239
+ for name, embedding in used_embeddings.items():
240
+ shorthash = embedding.shorthash
241
+ if not shorthash:
242
+ continue
243
+
244
+ name = name.replace(":", "").replace(",", "")
245
+ hashes.append(f"{name}: {shorthash}")
246
+
247
+ if hashes:
248
+ self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
249
+
250
+ if getattr(self.wrapped, 'return_pooled', False):
251
+ return torch.hstack(zs), zs[0].pooled
252
+ else:
253
+ return torch.hstack(zs)
254
+
255
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
256
+ """
257
+ sends one single prompt chunk to be encoded by transformers neural network.
258
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
259
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
260
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
261
+ corresponds to one token.
262
+ """
263
+ tokens = torch.asarray(remade_batch_tokens).to(devices.device)
264
+
265
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
266
+ if self.id_end != self.id_pad:
267
+ for batch_pos in range(len(remade_batch_tokens)):
268
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
269
+ tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
270
+
271
+ z = self.encode_with_transformers(tokens)
272
+
273
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
274
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
275
+ original_mean = z.mean()
276
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
277
+ new_mean = z.mean()
278
+ z *= (original_mean / new_mean)
279
+
280
+ return z
281
+
282
+
283
+ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
284
+ def __init__(self, wrapped, hijack):
285
+ super().__init__(wrapped, hijack)
286
+ self.tokenizer = wrapped.tokenizer
287
+
288
+ vocab = self.tokenizer.get_vocab()
289
+
290
+ self.comma_token = vocab.get(',</w>', None)
291
+
292
+ self.token_mults = {}
293
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
294
+ for text, ident in tokens_with_parens:
295
+ mult = 1.0
296
+ for c in text:
297
+ if c == '[':
298
+ mult /= 1.1
299
+ if c == ']':
300
+ mult *= 1.1
301
+ if c == '(':
302
+ mult *= 1.1
303
+ if c == ')':
304
+ mult /= 1.1
305
+
306
+ if mult != 1.0:
307
+ self.token_mults[ident] = mult
308
+
309
+ self.id_start = self.wrapped.tokenizer.bos_token_id
310
+ self.id_end = self.wrapped.tokenizer.eos_token_id
311
+ self.id_pad = self.id_end
312
+
313
+ def tokenize(self, texts):
314
+ tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
315
+
316
+ return tokenized
317
+
318
+ def encode_with_transformers(self, tokens):
319
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
320
+
321
+ if opts.CLIP_stop_at_last_layers > 1:
322
+ z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
323
+ z = self.wrapped.transformer.text_model.final_layer_norm(z)
324
+ else:
325
+ z = outputs.last_hidden_state
326
+
327
+ return z
328
+
329
+ def encode_embedding_init_text(self, init_text, nvpt):
330
+ embedding_layer = self.wrapped.transformer.text_model.embeddings
331
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
332
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
333
+
334
+ return embedded
335
+
336
+
337
+ class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
338
+ def __init__(self, wrapped, hijack):
339
+ super().__init__(wrapped, hijack)
340
+
341
+ def encode_with_transformers(self, tokens):
342
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
343
+
344
+ if self.wrapped.layer == "last":
345
+ z = outputs.last_hidden_state
346
+ else:
347
+ z = outputs.hidden_states[self.wrapped.layer_idx]
348
+
349
+ return z
modules/sd_hijack_clip_old.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules import sd_hijack_clip
2
+ from modules import shared
3
+
4
+
5
+ def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
6
+ id_start = self.id_start
7
+ id_end = self.id_end
8
+ maxlen = self.wrapped.max_length # you get to stay at 77
9
+ used_custom_terms = []
10
+ remade_batch_tokens = []
11
+ hijack_comments = []
12
+ hijack_fixes = []
13
+ token_count = 0
14
+
15
+ cache = {}
16
+ batch_tokens = self.tokenize(texts)
17
+ batch_multipliers = []
18
+ for tokens in batch_tokens:
19
+ tuple_tokens = tuple(tokens)
20
+
21
+ if tuple_tokens in cache:
22
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
23
+ else:
24
+ fixes = []
25
+ remade_tokens = []
26
+ multipliers = []
27
+ mult = 1.0
28
+
29
+ i = 0
30
+ while i < len(tokens):
31
+ token = tokens[i]
32
+
33
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
34
+
35
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
36
+ if mult_change is not None:
37
+ mult *= mult_change
38
+ i += 1
39
+ elif embedding is None:
40
+ remade_tokens.append(token)
41
+ multipliers.append(mult)
42
+ i += 1
43
+ else:
44
+ emb_len = int(embedding.vec.shape[0])
45
+ fixes.append((len(remade_tokens), embedding))
46
+ remade_tokens += [0] * emb_len
47
+ multipliers += [mult] * emb_len
48
+ used_custom_terms.append((embedding.name, embedding.checksum()))
49
+ i += embedding_length_in_tokens
50
+
51
+ if len(remade_tokens) > maxlen - 2:
52
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
53
+ ovf = remade_tokens[maxlen - 2:]
54
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
55
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
56
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
57
+
58
+ token_count = len(remade_tokens)
59
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
60
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
61
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
62
+
63
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
64
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
65
+
66
+ remade_batch_tokens.append(remade_tokens)
67
+ hijack_fixes.append(fixes)
68
+ batch_multipliers.append(multipliers)
69
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
70
+
71
+
72
+ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
73
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
74
+
75
+ self.hijack.comments += hijack_comments
76
+
77
+ if used_custom_terms:
78
+ embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
79
+ self.hijack.comments.append(f"Used embeddings: {embedding_names}")
80
+
81
+ self.hijack.fixes = hijack_fixes
82
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
modules/sd_hijack_inpainting.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import ldm.models.diffusion.ddpm
4
+ import ldm.models.diffusion.ddim
5
+ import ldm.models.diffusion.plms
6
+
7
+ from ldm.models.diffusion.ddim import noise_like
8
+ from ldm.models.diffusion.sampling_util import norm_thresholding
9
+
10
+
11
+ @torch.no_grad()
12
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
13
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
14
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
15
+ b, *_, device = *x.shape, x.device
16
+
17
+ def get_model_output(x, t):
18
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
19
+ e_t = self.model.apply_model(x, t, c)
20
+ else:
21
+ x_in = torch.cat([x] * 2)
22
+ t_in = torch.cat([t] * 2)
23
+
24
+ if isinstance(c, dict):
25
+ assert isinstance(unconditional_conditioning, dict)
26
+ c_in = {}
27
+ for k in c:
28
+ if isinstance(c[k], list):
29
+ c_in[k] = [
30
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
31
+ for i in range(len(c[k]))
32
+ ]
33
+ else:
34
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
35
+ else:
36
+ c_in = torch.cat([unconditional_conditioning, c])
37
+
38
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
39
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
40
+
41
+ if score_corrector is not None:
42
+ assert self.model.parameterization == "eps"
43
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
44
+
45
+ return e_t
46
+
47
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
48
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
49
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
50
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
51
+
52
+ def get_x_prev_and_pred_x0(e_t, index):
53
+ # select parameters corresponding to the currently considered timestep
54
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
55
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
56
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
57
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
58
+
59
+ # current prediction for x_0
60
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
61
+ if quantize_denoised:
62
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
63
+ if dynamic_threshold is not None:
64
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
65
+ # direction pointing to x_t
66
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
67
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
68
+ if noise_dropout > 0.:
69
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
70
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
71
+ return x_prev, pred_x0
72
+
73
+ e_t = get_model_output(x, t)
74
+ if len(old_eps) == 0:
75
+ # Pseudo Improved Euler (2nd order)
76
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
77
+ e_t_next = get_model_output(x_prev, t_next)
78
+ e_t_prime = (e_t + e_t_next) / 2
79
+ elif len(old_eps) == 1:
80
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
81
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
82
+ elif len(old_eps) == 2:
83
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
84
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
85
+ elif len(old_eps) >= 3:
86
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
87
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
88
+
89
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
90
+
91
+ return x_prev, pred_x0, e_t
92
+
93
+
94
+ def do_inpainting_hijack():
95
+ # p_sample_plms is needed because PLMS can't work with dicts as conditionings
96
+
97
+ ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
modules/sd_hijack_ip2p.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+
4
+ def should_hijack_ip2p(checkpoint_info):
5
+ from modules import sd_models_config
6
+
7
+ ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
8
+ cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
9
+
10
+ return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
modules/sd_hijack_open_clip.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open_clip.tokenizer
2
+ import torch
3
+
4
+ from modules import sd_hijack_clip, devices
5
+ from modules.shared import opts
6
+
7
+ tokenizer = open_clip.tokenizer._tokenizer
8
+
9
+
10
+ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
11
+ def __init__(self, wrapped, hijack):
12
+ super().__init__(wrapped, hijack)
13
+
14
+ self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
15
+ self.id_start = tokenizer.encoder["<start_of_text>"]
16
+ self.id_end = tokenizer.encoder["<end_of_text>"]
17
+ self.id_pad = 0
18
+
19
+ def tokenize(self, texts):
20
+ assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
21
+
22
+ tokenized = [tokenizer.encode(text) for text in texts]
23
+
24
+ return tokenized
25
+
26
+ def encode_with_transformers(self, tokens):
27
+ # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
28
+ z = self.wrapped.encode_with_transformer(tokens)
29
+
30
+ return z
31
+
32
+ def encode_embedding_init_text(self, init_text, nvpt):
33
+ ids = tokenizer.encode(init_text)
34
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
35
+ embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
36
+
37
+ return embedded
38
+
39
+
40
+ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
41
+ def __init__(self, wrapped, hijack):
42
+ super().__init__(wrapped, hijack)
43
+
44
+ self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
45
+ self.id_start = tokenizer.encoder["<start_of_text>"]
46
+ self.id_end = tokenizer.encoder["<end_of_text>"]
47
+ self.id_pad = 0
48
+
49
+ def tokenize(self, texts):
50
+ assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
51
+
52
+ tokenized = [tokenizer.encode(text) for text in texts]
53
+
54
+ return tokenized
55
+
56
+ def encode_with_transformers(self, tokens):
57
+ d = self.wrapped.encode_with_transformer(tokens)
58
+ z = d[self.wrapped.layer]
59
+
60
+ pooled = d.get("pooled")
61
+ if pooled is not None:
62
+ z.pooled = pooled
63
+
64
+ return z
65
+
66
+ def encode_embedding_init_text(self, init_text, nvpt):
67
+ ids = tokenizer.encode(init_text)
68
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
69
+ embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
70
+
71
+ return embedded
modules/sd_hijack_optimizations.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import math
3
+ import psutil
4
+
5
+ import torch
6
+ from torch import einsum
7
+
8
+ from ldm.util import default
9
+ from einops import rearrange
10
+
11
+ from modules import shared, errors, devices, sub_quadratic_attention
12
+ from modules.hypernetworks import hypernetwork
13
+
14
+ import ldm.modules.attention
15
+ import ldm.modules.diffusionmodules.model
16
+
17
+ import sgm.modules.attention
18
+ import sgm.modules.diffusionmodules.model
19
+
20
+ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
21
+ sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
22
+
23
+
24
+ class SdOptimization:
25
+ name: str = None
26
+ label: str | None = None
27
+ cmd_opt: str | None = None
28
+ priority: int = 0
29
+
30
+ def title(self):
31
+ if self.label is None:
32
+ return self.name
33
+
34
+ return f"{self.name} - {self.label}"
35
+
36
+ def is_available(self):
37
+ return True
38
+
39
+ def apply(self):
40
+ pass
41
+
42
+ def undo(self):
43
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
44
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
45
+
46
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
47
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
48
+
49
+
50
+ class SdOptimizationXformers(SdOptimization):
51
+ name = "xformers"
52
+ cmd_opt = "xformers"
53
+ priority = 100
54
+
55
+ def is_available(self):
56
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
57
+
58
+ def apply(self):
59
+ ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
60
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
61
+ sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
62
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
63
+
64
+
65
+ class SdOptimizationSdpNoMem(SdOptimization):
66
+ name = "sdp-no-mem"
67
+ label = "scaled dot product without memory efficient attention"
68
+ cmd_opt = "opt_sdp_no_mem_attention"
69
+ priority = 80
70
+
71
+ def is_available(self):
72
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
73
+
74
+ def apply(self):
75
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
76
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
77
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
78
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
79
+
80
+
81
+ class SdOptimizationSdp(SdOptimizationSdpNoMem):
82
+ name = "sdp"
83
+ label = "scaled dot product"
84
+ cmd_opt = "opt_sdp_attention"
85
+ priority = 70
86
+
87
+ def apply(self):
88
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
89
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
90
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
91
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
92
+
93
+
94
+ class SdOptimizationSubQuad(SdOptimization):
95
+ name = "sub-quadratic"
96
+ cmd_opt = "opt_sub_quad_attention"
97
+ priority = 10
98
+
99
+ def apply(self):
100
+ ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
101
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
102
+ sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
103
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
104
+
105
+
106
+ class SdOptimizationV1(SdOptimization):
107
+ name = "V1"
108
+ label = "original v1"
109
+ cmd_opt = "opt_split_attention_v1"
110
+ priority = 10
111
+
112
+ def apply(self):
113
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
114
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
115
+
116
+
117
+ class SdOptimizationInvokeAI(SdOptimization):
118
+ name = "InvokeAI"
119
+ cmd_opt = "opt_split_attention_invokeai"
120
+
121
+ @property
122
+ def priority(self):
123
+ return 1000 if not torch.cuda.is_available() else 10
124
+
125
+ def apply(self):
126
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
127
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
128
+
129
+
130
+ class SdOptimizationDoggettx(SdOptimization):
131
+ name = "Doggettx"
132
+ cmd_opt = "opt_split_attention"
133
+ priority = 90
134
+
135
+ def apply(self):
136
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
137
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
138
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
139
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
140
+
141
+
142
+ def list_optimizers(res):
143
+ res.extend([
144
+ SdOptimizationXformers(),
145
+ SdOptimizationSdpNoMem(),
146
+ SdOptimizationSdp(),
147
+ SdOptimizationSubQuad(),
148
+ SdOptimizationV1(),
149
+ SdOptimizationInvokeAI(),
150
+ SdOptimizationDoggettx(),
151
+ ])
152
+
153
+
154
+ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
155
+ try:
156
+ import xformers.ops
157
+ shared.xformers_available = True
158
+ except Exception:
159
+ errors.report("Cannot import xformers", exc_info=True)
160
+
161
+
162
+ def get_available_vram():
163
+ if shared.device.type == 'cuda':
164
+ stats = torch.cuda.memory_stats(shared.device)
165
+ mem_active = stats['active_bytes.all.current']
166
+ mem_reserved = stats['reserved_bytes.all.current']
167
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
168
+ mem_free_torch = mem_reserved - mem_active
169
+ mem_free_total = mem_free_cuda + mem_free_torch
170
+ return mem_free_total
171
+ else:
172
+ return psutil.virtual_memory().available
173
+
174
+
175
+ # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
176
+ def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
177
+ h = self.heads
178
+
179
+ q_in = self.to_q(x)
180
+ context = default(context, x)
181
+
182
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
183
+ k_in = self.to_k(context_k)
184
+ v_in = self.to_v(context_v)
185
+ del context, context_k, context_v, x
186
+
187
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
188
+ del q_in, k_in, v_in
189
+
190
+ dtype = q.dtype
191
+ if shared.opts.upcast_attn:
192
+ q, k, v = q.float(), k.float(), v.float()
193
+
194
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
195
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
196
+ for i in range(0, q.shape[0], 2):
197
+ end = i + 2
198
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
199
+ s1 *= self.scale
200
+
201
+ s2 = s1.softmax(dim=-1)
202
+ del s1
203
+
204
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
205
+ del s2
206
+ del q, k, v
207
+
208
+ r1 = r1.to(dtype)
209
+
210
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
211
+ del r1
212
+
213
+ return self.to_out(r2)
214
+
215
+
216
+ # taken from https://github.com/Doggettx/stable-diffusion and modified
217
+ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
218
+ h = self.heads
219
+
220
+ q_in = self.to_q(x)
221
+ context = default(context, x)
222
+
223
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
224
+ k_in = self.to_k(context_k)
225
+ v_in = self.to_v(context_v)
226
+
227
+ dtype = q_in.dtype
228
+ if shared.opts.upcast_attn:
229
+ q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
230
+
231
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
232
+ k_in = k_in * self.scale
233
+
234
+ del context, x
235
+
236
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
237
+ del q_in, k_in, v_in
238
+
239
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
240
+
241
+ mem_free_total = get_available_vram()
242
+
243
+ gb = 1024 ** 3
244
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
245
+ modifier = 3 if q.element_size() == 2 else 2.5
246
+ mem_required = tensor_size * modifier
247
+ steps = 1
248
+
249
+ if mem_required > mem_free_total:
250
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
251
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
252
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
253
+
254
+ if steps > 64:
255
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
256
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
257
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
258
+
259
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
260
+ for i in range(0, q.shape[1], slice_size):
261
+ end = i + slice_size
262
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
263
+
264
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
265
+ del s1
266
+
267
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
268
+ del s2
269
+
270
+ del q, k, v
271
+
272
+ r1 = r1.to(dtype)
273
+
274
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
275
+ del r1
276
+
277
+ return self.to_out(r2)
278
+
279
+
280
+ # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
281
+ mem_total_gb = psutil.virtual_memory().total // (1 << 30)
282
+
283
+
284
+ def einsum_op_compvis(q, k, v):
285
+ s = einsum('b i d, b j d -> b i j', q, k)
286
+ s = s.softmax(dim=-1, dtype=s.dtype)
287
+ return einsum('b i j, b j d -> b i d', s, v)
288
+
289
+
290
+ def einsum_op_slice_0(q, k, v, slice_size):
291
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
292
+ for i in range(0, q.shape[0], slice_size):
293
+ end = i + slice_size
294
+ r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
295
+ return r
296
+
297
+
298
+ def einsum_op_slice_1(q, k, v, slice_size):
299
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
300
+ for i in range(0, q.shape[1], slice_size):
301
+ end = i + slice_size
302
+ r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
303
+ return r
304
+
305
+
306
+ def einsum_op_mps_v1(q, k, v):
307
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
308
+ return einsum_op_compvis(q, k, v)
309
+ else:
310
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
311
+ if slice_size % 4096 == 0:
312
+ slice_size -= 1
313
+ return einsum_op_slice_1(q, k, v, slice_size)
314
+
315
+
316
+ def einsum_op_mps_v2(q, k, v):
317
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
318
+ return einsum_op_compvis(q, k, v)
319
+ else:
320
+ return einsum_op_slice_0(q, k, v, 1)
321
+
322
+
323
+ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
324
+ size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
325
+ if size_mb <= max_tensor_mb:
326
+ return einsum_op_compvis(q, k, v)
327
+ div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
328
+ if div <= q.shape[0]:
329
+ return einsum_op_slice_0(q, k, v, q.shape[0] // div)
330
+ return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
331
+
332
+
333
+ def einsum_op_cuda(q, k, v):
334
+ stats = torch.cuda.memory_stats(q.device)
335
+ mem_active = stats['active_bytes.all.current']
336
+ mem_reserved = stats['reserved_bytes.all.current']
337
+ mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
338
+ mem_free_torch = mem_reserved - mem_active
339
+ mem_free_total = mem_free_cuda + mem_free_torch
340
+ # Divide factor of safety as there's copying and fragmentation
341
+ return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
342
+
343
+
344
+ def einsum_op(q, k, v):
345
+ if q.device.type == 'cuda':
346
+ return einsum_op_cuda(q, k, v)
347
+
348
+ if q.device.type == 'mps':
349
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
350
+ return einsum_op_mps_v1(q, k, v)
351
+ return einsum_op_mps_v2(q, k, v)
352
+
353
+ # Smaller slices are faster due to L2/L3/SLC caches.
354
+ # Tested on i7 with 8MB L3 cache.
355
+ return einsum_op_tensor_mem(q, k, v, 32)
356
+
357
+
358
+ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
359
+ h = self.heads
360
+
361
+ q = self.to_q(x)
362
+ context = default(context, x)
363
+
364
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
365
+ k = self.to_k(context_k)
366
+ v = self.to_v(context_v)
367
+ del context, context_k, context_v, x
368
+
369
+ dtype = q.dtype
370
+ if shared.opts.upcast_attn:
371
+ q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
372
+
373
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
374
+ k = k * self.scale
375
+
376
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
377
+ r = einsum_op(q, k, v)
378
+ r = r.to(dtype)
379
+ return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
380
+
381
+ # -- End of code from https://github.com/invoke-ai/InvokeAI --
382
+
383
+
384
+ # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
385
+ # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
386
+ def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
387
+ assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
388
+
389
+ h = self.heads
390
+
391
+ q = self.to_q(x)
392
+ context = default(context, x)
393
+
394
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
395
+ k = self.to_k(context_k)
396
+ v = self.to_v(context_v)
397
+ del context, context_k, context_v, x
398
+
399
+ q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
400
+ k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
401
+ v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
402
+
403
+ if q.device.type == 'mps':
404
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
405
+
406
+ dtype = q.dtype
407
+ if shared.opts.upcast_attn:
408
+ q, k = q.float(), k.float()
409
+
410
+ x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
411
+
412
+ x = x.to(dtype)
413
+
414
+ x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
415
+
416
+ out_proj, dropout = self.to_out
417
+ x = out_proj(x)
418
+ x = dropout(x)
419
+
420
+ return x
421
+
422
+
423
+ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
424
+ bytes_per_token = torch.finfo(q.dtype).bits//8
425
+ batch_x_heads, q_tokens, _ = q.shape
426
+ _, k_tokens, _ = k.shape
427
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
428
+
429
+ if chunk_threshold is None:
430
+ chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
431
+ elif chunk_threshold == 0:
432
+ chunk_threshold_bytes = None
433
+ else:
434
+ chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
435
+
436
+ if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
437
+ kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
438
+ elif kv_chunk_size_min == 0:
439
+ kv_chunk_size_min = None
440
+
441
+ if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
442
+ # the big matmul fits into our memory limit; do everything in 1 chunk,
443
+ # i.e. send it down the unchunked fast-path
444
+ kv_chunk_size = k_tokens
445
+
446
+ with devices.without_autocast(disable=q.dtype == v.dtype):
447
+ return sub_quadratic_attention.efficient_dot_product_attention(
448
+ q,
449
+ k,
450
+ v,
451
+ query_chunk_size=q_chunk_size,
452
+ kv_chunk_size=kv_chunk_size,
453
+ kv_chunk_size_min = kv_chunk_size_min,
454
+ use_checkpoint=use_checkpoint,
455
+ )
456
+
457
+
458
+ def get_xformers_flash_attention_op(q, k, v):
459
+ if not shared.cmd_opts.xformers_flash_attention:
460
+ return None
461
+
462
+ try:
463
+ flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
464
+ fw, bw = flash_attention_op
465
+ if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
466
+ return flash_attention_op
467
+ except Exception as e:
468
+ errors.display_once(e, "enabling flash attention")
469
+
470
+ return None
471
+
472
+
473
+ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
474
+ h = self.heads
475
+ q_in = self.to_q(x)
476
+ context = default(context, x)
477
+
478
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
479
+ k_in = self.to_k(context_k)
480
+ v_in = self.to_v(context_v)
481
+
482
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
483
+ del q_in, k_in, v_in
484
+
485
+ dtype = q.dtype
486
+ if shared.opts.upcast_attn:
487
+ q, k, v = q.float(), k.float(), v.float()
488
+
489
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
490
+
491
+ out = out.to(dtype)
492
+
493
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
494
+ return self.to_out(out)
495
+
496
+
497
+ # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
498
+ # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
499
+ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
500
+ batch_size, sequence_length, inner_dim = x.shape
501
+
502
+ if mask is not None:
503
+ mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
504
+ mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
505
+
506
+ h = self.heads
507
+ q_in = self.to_q(x)
508
+ context = default(context, x)
509
+
510
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
511
+ k_in = self.to_k(context_k)
512
+ v_in = self.to_v(context_v)
513
+
514
+ head_dim = inner_dim // h
515
+ q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
516
+ k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
517
+ v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
518
+
519
+ del q_in, k_in, v_in
520
+
521
+ dtype = q.dtype
522
+ if shared.opts.upcast_attn:
523
+ q, k, v = q.float(), k.float(), v.float()
524
+
525
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
526
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
527
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
528
+ )
529
+
530
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
531
+ hidden_states = hidden_states.to(dtype)
532
+
533
+ # linear proj
534
+ hidden_states = self.to_out[0](hidden_states)
535
+ # dropout
536
+ hidden_states = self.to_out[1](hidden_states)
537
+ return hidden_states
538
+
539
+
540
+ def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
541
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
542
+ return scaled_dot_product_attention_forward(self, x, context, mask)
543
+
544
+
545
+ def cross_attention_attnblock_forward(self, x):
546
+ h_ = x
547
+ h_ = self.norm(h_)
548
+ q1 = self.q(h_)
549
+ k1 = self.k(h_)
550
+ v = self.v(h_)
551
+
552
+ # compute attention
553
+ b, c, h, w = q1.shape
554
+
555
+ q2 = q1.reshape(b, c, h*w)
556
+ del q1
557
+
558
+ q = q2.permute(0, 2, 1) # b,hw,c
559
+ del q2
560
+
561
+ k = k1.reshape(b, c, h*w) # b,c,hw
562
+ del k1
563
+
564
+ h_ = torch.zeros_like(k, device=q.device)
565
+
566
+ mem_free_total = get_available_vram()
567
+
568
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
569
+ mem_required = tensor_size * 2.5
570
+ steps = 1
571
+
572
+ if mem_required > mem_free_total:
573
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
574
+
575
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
576
+ for i in range(0, q.shape[1], slice_size):
577
+ end = i + slice_size
578
+
579
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
580
+ w2 = w1 * (int(c)**(-0.5))
581
+ del w1
582
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
583
+ del w2
584
+
585
+ # attend to values
586
+ v1 = v.reshape(b, c, h*w)
587
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
588
+ del w3
589
+
590
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
591
+ del v1, w4
592
+
593
+ h2 = h_.reshape(b, c, h, w)
594
+ del h_
595
+
596
+ h3 = self.proj_out(h2)
597
+ del h2
598
+
599
+ h3 += x
600
+
601
+ return h3
602
+
603
+
604
+ def xformers_attnblock_forward(self, x):
605
+ try:
606
+ h_ = x
607
+ h_ = self.norm(h_)
608
+ q = self.q(h_)
609
+ k = self.k(h_)
610
+ v = self.v(h_)
611
+ b, c, h, w = q.shape
612
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
613
+ dtype = q.dtype
614
+ if shared.opts.upcast_attn:
615
+ q, k = q.float(), k.float()
616
+ q = q.contiguous()
617
+ k = k.contiguous()
618
+ v = v.contiguous()
619
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
620
+ out = out.to(dtype)
621
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
622
+ out = self.proj_out(out)
623
+ return x + out
624
+ except NotImplementedError:
625
+ return cross_attention_attnblock_forward(self, x)
626
+
627
+
628
+ def sdp_attnblock_forward(self, x):
629
+ h_ = x
630
+ h_ = self.norm(h_)
631
+ q = self.q(h_)
632
+ k = self.k(h_)
633
+ v = self.v(h_)
634
+ b, c, h, w = q.shape
635
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
636
+ dtype = q.dtype
637
+ if shared.opts.upcast_attn:
638
+ q, k, v = q.float(), k.float(), v.float()
639
+ q = q.contiguous()
640
+ k = k.contiguous()
641
+ v = v.contiguous()
642
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
643
+ out = out.to(dtype)
644
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
645
+ out = self.proj_out(out)
646
+ return x + out
647
+
648
+
649
+ def sdp_no_mem_attnblock_forward(self, x):
650
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
651
+ return sdp_attnblock_forward(self, x)
652
+
653
+
654
+ def sub_quad_attnblock_forward(self, x):
655
+ h_ = x
656
+ h_ = self.norm(h_)
657
+ q = self.q(h_)
658
+ k = self.k(h_)
659
+ v = self.v(h_)
660
+ b, c, h, w = q.shape
661
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
662
+ q = q.contiguous()
663
+ k = k.contiguous()
664
+ v = v.contiguous()
665
+ out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
666
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
667
+ out = self.proj_out(out)
668
+ return x + out
modules/sd_hijack_unet.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from packaging import version
3
+
4
+ from modules import devices
5
+ from modules.sd_hijack_utils import CondFunc
6
+
7
+
8
+ class TorchHijackForUnet:
9
+ """
10
+ This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
11
+ this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
12
+ """
13
+
14
+ def __getattr__(self, item):
15
+ if item == 'cat':
16
+ return self.cat
17
+
18
+ if hasattr(torch, item):
19
+ return getattr(torch, item)
20
+
21
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
22
+
23
+ def cat(self, tensors, *args, **kwargs):
24
+ if len(tensors) == 2:
25
+ a, b = tensors
26
+ if a.shape[-2:] != b.shape[-2:]:
27
+ a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
28
+
29
+ tensors = (a, b)
30
+
31
+ return torch.cat(tensors, *args, **kwargs)
32
+
33
+
34
+ th = TorchHijackForUnet()
35
+
36
+
37
+ # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
38
+ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
39
+
40
+ if isinstance(cond, dict):
41
+ for y in cond.keys():
42
+ if isinstance(cond[y], list):
43
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
44
+ else:
45
+ cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
46
+
47
+ with devices.autocast():
48
+ return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
49
+
50
+
51
+ class GELUHijack(torch.nn.GELU, torch.nn.Module):
52
+ def __init__(self, *args, **kwargs):
53
+ torch.nn.GELU.__init__(self, *args, **kwargs)
54
+ def forward(self, x):
55
+ if devices.unet_needs_upcast:
56
+ return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
57
+ else:
58
+ return torch.nn.GELU.forward(self, x)
59
+
60
+
61
+ ddpm_edit_hijack = None
62
+ def hijack_ddpm_edit():
63
+ global ddpm_edit_hijack
64
+ if not ddpm_edit_hijack:
65
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
66
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
67
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
68
+
69
+
70
+ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
71
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
72
+ CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
73
+ if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
74
+ CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
75
+ CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
76
+ CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
77
+
78
+ first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
79
+ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
80
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
81
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
82
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
83
+
84
+ CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
85
+ CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
modules/sd_hijack_utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ class CondFunc:
4
+ def __new__(cls, orig_func, sub_func, cond_func):
5
+ self = super(CondFunc, cls).__new__(cls)
6
+ if isinstance(orig_func, str):
7
+ func_path = orig_func.split('.')
8
+ for i in range(len(func_path)-1, -1, -1):
9
+ try:
10
+ resolved_obj = importlib.import_module('.'.join(func_path[:i]))
11
+ break
12
+ except ImportError:
13
+ pass
14
+ for attr_name in func_path[i:-1]:
15
+ resolved_obj = getattr(resolved_obj, attr_name)
16
+ orig_func = getattr(resolved_obj, func_path[-1])
17
+ setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
18
+ self.__init__(orig_func, sub_func, cond_func)
19
+ return lambda *args, **kwargs: self(*args, **kwargs)
20
+ def __init__(self, orig_func, sub_func, cond_func):
21
+ self.__orig_func = orig_func
22
+ self.__sub_func = sub_func
23
+ self.__cond_func = cond_func
24
+ def __call__(self, *args, **kwargs):
25
+ if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
26
+ return self.__sub_func(self.__orig_func, *args, **kwargs)
27
+ else:
28
+ return self.__orig_func(*args, **kwargs)
modules/sd_hijack_xlmr.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from modules import sd_hijack_clip, devices
4
+
5
+
6
+ class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
7
+ def __init__(self, wrapped, hijack):
8
+ super().__init__(wrapped, hijack)
9
+
10
+ self.id_start = wrapped.config.bos_token_id
11
+ self.id_end = wrapped.config.eos_token_id
12
+ self.id_pad = wrapped.config.pad_token_id
13
+
14
+ self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
15
+
16
+ def encode_with_transformers(self, tokens):
17
+ # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
18
+ # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
19
+ # layer to work with - you have to use the last
20
+
21
+ attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
22
+ features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
23
+ z = features['projection_state']
24
+
25
+ return z
26
+
27
+ def encode_embedding_init_text(self, init_text, nvpt):
28
+ embedding_layer = self.wrapped.roberta.embeddings
29
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
30
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
31
+
32
+ return embedded
modules/sd_models.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import os.path
3
+ import sys
4
+ import gc
5
+ import threading
6
+
7
+ import torch
8
+ import re
9
+ import safetensors.torch
10
+ from omegaconf import OmegaConf
11
+ from os import mkdir
12
+ from urllib import request
13
+ import ldm.modules.midas as midas
14
+
15
+ from ldm.util import instantiate_from_config
16
+
17
+ from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
18
+ from modules.sd_hijack_inpainting import do_inpainting_hijack
19
+ from modules.timer import Timer
20
+ import tomesd
21
+
22
+ model_dir = "Stable-diffusion"
23
+ model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
24
+
25
+ checkpoints_list = {}
26
+ checkpoint_aliases = {}
27
+ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
28
+ checkpoints_loaded = collections.OrderedDict()
29
+
30
+
31
+ class CheckpointInfo:
32
+ def __init__(self, filename):
33
+ self.filename = filename
34
+ abspath = os.path.abspath(filename)
35
+
36
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
37
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
38
+ elif abspath.startswith(model_path):
39
+ name = abspath.replace(model_path, '')
40
+ else:
41
+ name = os.path.basename(filename)
42
+
43
+ if name.startswith("\\") or name.startswith("/"):
44
+ name = name[1:]
45
+
46
+ self.name = name
47
+ self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
48
+ self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
49
+ self.hash = model_hash(filename)
50
+
51
+ self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
52
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
53
+
54
+ self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
55
+
56
+ self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
57
+
58
+ self.metadata = {}
59
+
60
+ _, ext = os.path.splitext(self.filename)
61
+ if ext.lower() == ".safetensors":
62
+ try:
63
+ self.metadata = read_metadata_from_safetensors(filename)
64
+ except Exception as e:
65
+ errors.display(e, f"reading checkpoint metadata: {filename}")
66
+
67
+ def register(self):
68
+ checkpoints_list[self.title] = self
69
+ for id in self.ids:
70
+ checkpoint_aliases[id] = self
71
+
72
+ def calculate_shorthash(self):
73
+ self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
74
+ if self.sha256 is None:
75
+ return
76
+
77
+ self.shorthash = self.sha256[0:10]
78
+
79
+ if self.shorthash not in self.ids:
80
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
81
+
82
+ checkpoints_list.pop(self.title)
83
+ self.title = f'{self.name} [{self.shorthash}]'
84
+ self.register()
85
+
86
+ return self.shorthash
87
+
88
+
89
+ try:
90
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
91
+ from transformers import logging, CLIPModel # noqa: F401
92
+
93
+ logging.set_verbosity_error()
94
+ except Exception:
95
+ pass
96
+
97
+
98
+ def setup_model():
99
+ os.makedirs(model_path, exist_ok=True)
100
+
101
+ enable_midas_autodownload()
102
+
103
+
104
+ def checkpoint_tiles():
105
+ def convert(name):
106
+ return int(name) if name.isdigit() else name.lower()
107
+
108
+ def alphanumeric_key(key):
109
+ return [convert(c) for c in re.split('([0-9]+)', key)]
110
+
111
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
112
+
113
+
114
+ def list_models():
115
+ checkpoints_list.clear()
116
+ checkpoint_aliases.clear()
117
+
118
+ cmd_ckpt = shared.cmd_opts.ckpt
119
+ if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
120
+ model_url = None
121
+ else:
122
+ model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
123
+
124
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
125
+
126
+ if os.path.exists(cmd_ckpt):
127
+ checkpoint_info = CheckpointInfo(cmd_ckpt)
128
+ checkpoint_info.register()
129
+
130
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
131
+ elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
132
+ print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
133
+
134
+ for filename in sorted(model_list, key=str.lower):
135
+ checkpoint_info = CheckpointInfo(filename)
136
+ checkpoint_info.register()
137
+
138
+
139
+ def get_closet_checkpoint_match(search_string):
140
+ checkpoint_info = checkpoint_aliases.get(search_string, None)
141
+ if checkpoint_info is not None:
142
+ return checkpoint_info
143
+
144
+ found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
145
+ if found:
146
+ return found[0]
147
+
148
+ return None
149
+
150
+
151
+ def model_hash(filename):
152
+ """old hash that only looks at a small part of the file and is prone to collisions"""
153
+
154
+ try:
155
+ with open(filename, "rb") as file:
156
+ import hashlib
157
+ m = hashlib.sha256()
158
+
159
+ file.seek(0x100000)
160
+ m.update(file.read(0x10000))
161
+ return m.hexdigest()[0:8]
162
+ except FileNotFoundError:
163
+ return 'NOFILE'
164
+
165
+
166
+ def select_checkpoint():
167
+ """Raises `FileNotFoundError` if no checkpoints are found."""
168
+ model_checkpoint = shared.opts.sd_model_checkpoint
169
+
170
+ checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
171
+ if checkpoint_info is not None:
172
+ return checkpoint_info
173
+
174
+ if len(checkpoints_list) == 0:
175
+ error_message = "No checkpoints found. When searching for checkpoints, looked at:"
176
+ if shared.cmd_opts.ckpt is not None:
177
+ error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
178
+ error_message += f"\n - directory {model_path}"
179
+ if shared.cmd_opts.ckpt_dir is not None:
180
+ error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
181
+ error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
182
+ raise FileNotFoundError(error_message)
183
+
184
+ checkpoint_info = next(iter(checkpoints_list.values()))
185
+ if model_checkpoint is not None:
186
+ print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
187
+
188
+ return checkpoint_info
189
+
190
+
191
+ checkpoint_dict_replacements = {
192
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
193
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
194
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
195
+ }
196
+
197
+
198
+ def transform_checkpoint_dict_key(k):
199
+ for text, replacement in checkpoint_dict_replacements.items():
200
+ if k.startswith(text):
201
+ k = replacement + k[len(text):]
202
+
203
+ return k
204
+
205
+
206
+ def get_state_dict_from_checkpoint(pl_sd):
207
+ pl_sd = pl_sd.pop("state_dict", pl_sd)
208
+ pl_sd.pop("state_dict", None)
209
+
210
+ sd = {}
211
+ for k, v in pl_sd.items():
212
+ new_key = transform_checkpoint_dict_key(k)
213
+
214
+ if new_key is not None:
215
+ sd[new_key] = v
216
+
217
+ pl_sd.clear()
218
+ pl_sd.update(sd)
219
+
220
+ return pl_sd
221
+
222
+
223
+ def read_metadata_from_safetensors(filename):
224
+ import json
225
+
226
+ with open(filename, mode="rb") as file:
227
+ metadata_len = file.read(8)
228
+ metadata_len = int.from_bytes(metadata_len, "little")
229
+ json_start = file.read(2)
230
+
231
+ assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
232
+ json_data = json_start + file.read(metadata_len-2)
233
+ json_obj = json.loads(json_data)
234
+
235
+ res = {}
236
+ for k, v in json_obj.get("__metadata__", {}).items():
237
+ res[k] = v
238
+ if isinstance(v, str) and v[0:1] == '{':
239
+ try:
240
+ res[k] = json.loads(v)
241
+ except Exception:
242
+ pass
243
+
244
+ return res
245
+
246
+
247
+ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
248
+ _, extension = os.path.splitext(checkpoint_file)
249
+ if extension.lower() == ".safetensors":
250
+ device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
251
+
252
+ if not shared.opts.disable_mmap_load_safetensors:
253
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
254
+ else:
255
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
256
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
257
+ else:
258
+ pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
259
+
260
+ if print_global_state and "global_step" in pl_sd:
261
+ print(f"Global Step: {pl_sd['global_step']}")
262
+
263
+ sd = get_state_dict_from_checkpoint(pl_sd)
264
+ return sd
265
+
266
+
267
+ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
268
+ sd_model_hash = checkpoint_info.calculate_shorthash()
269
+ timer.record("calculate hash")
270
+
271
+ if checkpoint_info in checkpoints_loaded:
272
+ # use checkpoint cache
273
+ print(f"Loading weights [{sd_model_hash}] from cache")
274
+ return checkpoints_loaded[checkpoint_info]
275
+
276
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
277
+ res = read_state_dict(checkpoint_info.filename)
278
+ timer.record("load weights from disk")
279
+
280
+ return res
281
+
282
+
283
+ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
284
+ sd_model_hash = checkpoint_info.calculate_shorthash()
285
+ timer.record("calculate hash")
286
+
287
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
288
+
289
+ if state_dict is None:
290
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
291
+
292
+ model.is_sdxl = hasattr(model, 'conditioner')
293
+ model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
294
+ model.is_sd1 = not model.is_sdxl and not model.is_sd2
295
+
296
+ if model.is_sdxl:
297
+ sd_models_xl.extend_sdxl(model)
298
+
299
+ model.load_state_dict(state_dict, strict=False)
300
+ del state_dict
301
+ timer.record("apply weights to model")
302
+
303
+ if shared.opts.sd_checkpoint_cache > 0:
304
+ # cache newly loaded model
305
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
306
+
307
+ if shared.cmd_opts.opt_channelslast:
308
+ model.to(memory_format=torch.channels_last)
309
+ timer.record("apply channels_last")
310
+
311
+ if not shared.cmd_opts.no_half:
312
+ vae = model.first_stage_model
313
+ depth_model = getattr(model, 'depth_model', None)
314
+
315
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
316
+ if shared.cmd_opts.no_half_vae:
317
+ model.first_stage_model = None
318
+ # with --upcast-sampling, don't convert the depth model weights to float16
319
+ if shared.cmd_opts.upcast_sampling and depth_model:
320
+ model.depth_model = None
321
+
322
+ model.half()
323
+ model.first_stage_model = vae
324
+ if depth_model:
325
+ model.depth_model = depth_model
326
+
327
+ timer.record("apply half()")
328
+
329
+ devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
330
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
331
+
332
+ model.first_stage_model.to(devices.dtype_vae)
333
+ timer.record("apply dtype to VAE")
334
+
335
+ # clean up cache if limit is reached
336
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
337
+ checkpoints_loaded.popitem(last=False)
338
+
339
+ model.sd_model_hash = sd_model_hash
340
+ model.sd_model_checkpoint = checkpoint_info.filename
341
+ model.sd_checkpoint_info = checkpoint_info
342
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
343
+
344
+ if hasattr(model, 'logvar'):
345
+ model.logvar = model.logvar.to(devices.device) # fix for training
346
+
347
+ sd_vae.delete_base_vae()
348
+ sd_vae.clear_loaded_vae()
349
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
350
+ sd_vae.load_vae(model, vae_file, vae_source)
351
+ timer.record("load VAE")
352
+
353
+
354
+ def enable_midas_autodownload():
355
+ """
356
+ Gives the ldm.modules.midas.api.load_model function automatic downloading.
357
+
358
+ When the 512-depth-ema model, and other future models like it, is loaded,
359
+ it calls midas.api.load_model to load the associated midas depth model.
360
+ This function applies a wrapper to download the model to the correct
361
+ location automatically.
362
+ """
363
+
364
+ midas_path = os.path.join(paths.models_path, 'midas')
365
+
366
+ # stable-diffusion-stability-ai hard-codes the midas model path to
367
+ # a location that differs from where other scripts using this model look.
368
+ # HACK: Overriding the path here.
369
+ for k, v in midas.api.ISL_PATHS.items():
370
+ file_name = os.path.basename(v)
371
+ midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
372
+
373
+ midas_urls = {
374
+ "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
375
+ "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
376
+ "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
377
+ "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
378
+ }
379
+
380
+ midas.api.load_model_inner = midas.api.load_model
381
+
382
+ def load_model_wrapper(model_type):
383
+ path = midas.api.ISL_PATHS[model_type]
384
+ if not os.path.exists(path):
385
+ if not os.path.exists(midas_path):
386
+ mkdir(midas_path)
387
+
388
+ print(f"Downloading midas model weights for {model_type} to {path}")
389
+ request.urlretrieve(midas_urls[model_type], path)
390
+ print(f"{model_type} downloaded")
391
+
392
+ return midas.api.load_model_inner(model_type)
393
+
394
+ midas.api.load_model = load_model_wrapper
395
+
396
+
397
+ def repair_config(sd_config):
398
+
399
+ if not hasattr(sd_config.model.params, "use_ema"):
400
+ sd_config.model.params.use_ema = False
401
+
402
+ if hasattr(sd_config.model.params, 'unet_config'):
403
+ if shared.cmd_opts.no_half:
404
+ sd_config.model.params.unet_config.params.use_fp16 = False
405
+ elif shared.cmd_opts.upcast_sampling:
406
+ sd_config.model.params.unet_config.params.use_fp16 = True
407
+
408
+ if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
409
+ sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
410
+
411
+ # For UnCLIP-L, override the hardcoded karlo directory
412
+ if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
413
+ karlo_path = os.path.join(paths.models_path, 'karlo')
414
+ sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
415
+
416
+
417
+ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
418
+ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
419
+ sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
420
+ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
421
+
422
+
423
+ class SdModelData:
424
+ def __init__(self):
425
+ self.sd_model = None
426
+ self.was_loaded_at_least_once = False
427
+ self.lock = threading.Lock()
428
+
429
+ def get_sd_model(self):
430
+ if self.was_loaded_at_least_once:
431
+ return self.sd_model
432
+
433
+ if self.sd_model is None:
434
+ with self.lock:
435
+ if self.sd_model is not None or self.was_loaded_at_least_once:
436
+ return self.sd_model
437
+
438
+ try:
439
+ load_model()
440
+ except Exception as e:
441
+ errors.display(e, "loading stable diffusion model", full_traceback=True)
442
+ print("", file=sys.stderr)
443
+ print("Stable diffusion model failed to load", file=sys.stderr)
444
+ self.sd_model = None
445
+
446
+ return self.sd_model
447
+
448
+ def set_sd_model(self, v):
449
+ self.sd_model = v
450
+
451
+
452
+ model_data = SdModelData()
453
+
454
+
455
+ def get_empty_cond(sd_model):
456
+ if hasattr(sd_model, 'conditioner'):
457
+ d = sd_model.get_learned_conditioning([""])
458
+ return d['crossattn']
459
+ else:
460
+ return sd_model.cond_stage_model([""])
461
+
462
+
463
+
464
+ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
465
+ from modules import lowvram, sd_hijack
466
+ checkpoint_info = checkpoint_info or select_checkpoint()
467
+
468
+ if model_data.sd_model:
469
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
470
+ model_data.sd_model = None
471
+ gc.collect()
472
+ devices.torch_gc()
473
+
474
+ do_inpainting_hijack()
475
+
476
+ timer = Timer()
477
+
478
+ if already_loaded_state_dict is not None:
479
+ state_dict = already_loaded_state_dict
480
+ else:
481
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
482
+
483
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
484
+ clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
485
+
486
+ timer.record("find config")
487
+
488
+ sd_config = OmegaConf.load(checkpoint_config)
489
+ repair_config(sd_config)
490
+
491
+ timer.record("load config")
492
+
493
+ print(f"Creating model from config: {checkpoint_config}")
494
+
495
+ sd_model = None
496
+ try:
497
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
498
+ sd_model = instantiate_from_config(sd_config.model)
499
+ except Exception:
500
+ pass
501
+
502
+ if sd_model is None:
503
+ print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
504
+ sd_model = instantiate_from_config(sd_config.model)
505
+
506
+ sd_model.used_config = checkpoint_config
507
+
508
+ timer.record("create model")
509
+
510
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
511
+
512
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
513
+ lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
514
+ else:
515
+ sd_model.to(shared.device)
516
+
517
+ timer.record("move model to device")
518
+
519
+ sd_hijack.model_hijack.hijack(sd_model)
520
+
521
+ timer.record("hijack")
522
+
523
+ sd_model.eval()
524
+ model_data.sd_model = sd_model
525
+ model_data.was_loaded_at_least_once = True
526
+
527
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
528
+
529
+ timer.record("load textual inversion embeddings")
530
+
531
+ script_callbacks.model_loaded_callback(sd_model)
532
+
533
+ timer.record("scripts callbacks")
534
+
535
+ with devices.autocast(), torch.no_grad():
536
+ sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
537
+
538
+ timer.record("calculate empty prompt")
539
+
540
+ print(f"Model loaded in {timer.summary()}.")
541
+
542
+ return sd_model
543
+
544
+
545
+ def reload_model_weights(sd_model=None, info=None):
546
+ from modules import lowvram, devices, sd_hijack
547
+ checkpoint_info = info or select_checkpoint()
548
+
549
+ if not sd_model:
550
+ sd_model = model_data.sd_model
551
+
552
+ if sd_model is None: # previous model load failed
553
+ current_checkpoint_info = None
554
+ else:
555
+ current_checkpoint_info = sd_model.sd_checkpoint_info
556
+ if sd_model.sd_model_checkpoint == checkpoint_info.filename:
557
+ return
558
+
559
+ sd_unet.apply_unet("None")
560
+
561
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
562
+ lowvram.send_everything_to_cpu()
563
+ else:
564
+ sd_model.to(devices.cpu)
565
+
566
+ sd_hijack.model_hijack.undo_hijack(sd_model)
567
+
568
+ timer = Timer()
569
+
570
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
571
+
572
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
573
+
574
+ timer.record("find config")
575
+
576
+ if sd_model is None or checkpoint_config != sd_model.used_config:
577
+ del sd_model
578
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict)
579
+ return model_data.sd_model
580
+
581
+ try:
582
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
583
+ except Exception:
584
+ print("Failed to load checkpoint, restoring previous")
585
+ load_model_weights(sd_model, current_checkpoint_info, None, timer)
586
+ raise
587
+ finally:
588
+ sd_hijack.model_hijack.hijack(sd_model)
589
+ timer.record("hijack")
590
+
591
+ script_callbacks.model_loaded_callback(sd_model)
592
+ timer.record("script callbacks")
593
+
594
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
595
+ sd_model.to(devices.device)
596
+ timer.record("move model to device")
597
+
598
+ print(f"Weights loaded in {timer.summary()}.")
599
+
600
+ return sd_model
601
+
602
+
603
+ def unload_model_weights(sd_model=None, info=None):
604
+ from modules import devices, sd_hijack
605
+ timer = Timer()
606
+
607
+ if model_data.sd_model:
608
+ model_data.sd_model.to(devices.cpu)
609
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
610
+ model_data.sd_model = None
611
+ sd_model = None
612
+ gc.collect()
613
+ devices.torch_gc()
614
+
615
+ print(f"Unloaded weights {timer.summary()}.")
616
+
617
+ return sd_model
618
+
619
+
620
+ def apply_token_merging(sd_model, token_merging_ratio):
621
+ """
622
+ Applies speed and memory optimizations from tomesd.
623
+ """
624
+
625
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
626
+
627
+ if current_token_merging_ratio == token_merging_ratio:
628
+ return
629
+
630
+ if current_token_merging_ratio > 0:
631
+ tomesd.remove_patch(sd_model)
632
+
633
+ if token_merging_ratio > 0:
634
+ tomesd.apply_patch(
635
+ sd_model,
636
+ ratio=token_merging_ratio,
637
+ use_rand=False, # can cause issues with some samplers
638
+ merge_attn=True,
639
+ merge_crossattn=False,
640
+ merge_mlp=False
641
+ )
642
+
643
+ sd_model.applied_token_merged_ratio = token_merging_ratio
modules/sd_models_config.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+
5
+ from modules import shared, paths, sd_disable_initialization
6
+
7
+ sd_configs_path = shared.sd_configs_path
8
+ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
9
+ sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
10
+
11
+
12
+ config_default = shared.sd_default_config
13
+ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
14
+ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
15
+ config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
16
+ config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
17
+ config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
18
+ config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
19
+ config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
20
+ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
21
+ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
22
+ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
23
+ config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
24
+
25
+
26
+ def is_using_v_parameterization_for_sd2(state_dict):
27
+ """
28
+ Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
29
+ """
30
+
31
+ import ldm.modules.diffusionmodules.openaimodel
32
+ from modules import devices
33
+
34
+ device = devices.cpu
35
+
36
+ with sd_disable_initialization.DisableInitialization():
37
+ unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
38
+ use_checkpoint=True,
39
+ use_fp16=False,
40
+ image_size=32,
41
+ in_channels=4,
42
+ out_channels=4,
43
+ model_channels=320,
44
+ attention_resolutions=[4, 2, 1],
45
+ num_res_blocks=2,
46
+ channel_mult=[1, 2, 4, 4],
47
+ num_head_channels=64,
48
+ use_spatial_transformer=True,
49
+ use_linear_in_transformer=True,
50
+ transformer_depth=1,
51
+ context_dim=1024,
52
+ legacy=False
53
+ )
54
+ unet.eval()
55
+
56
+ with torch.no_grad():
57
+ unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
58
+ unet.load_state_dict(unet_sd, strict=True)
59
+ unet.to(device=device, dtype=torch.float)
60
+
61
+ test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
62
+ x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
63
+
64
+ out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
65
+
66
+ return out < -1
67
+
68
+
69
+ def guess_model_config_from_state_dict(sd, filename):
70
+ sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
71
+ diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
72
+ sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
73
+
74
+ if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
75
+ return config_sdxl
76
+ if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
77
+ return config_sdxl_refiner
78
+ elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
79
+ return config_depth_model
80
+ elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
81
+ return config_unclip
82
+ elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
83
+ return config_unopenclip
84
+
85
+ if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
86
+ if diffusion_model_input.shape[1] == 9:
87
+ return config_sd2_inpainting
88
+ elif is_using_v_parameterization_for_sd2(sd):
89
+ return config_sd2v
90
+ else:
91
+ return config_sd2
92
+
93
+ if diffusion_model_input is not None:
94
+ if diffusion_model_input.shape[1] == 9:
95
+ return config_inpainting
96
+ if diffusion_model_input.shape[1] == 8:
97
+ return config_instruct_pix2pix
98
+
99
+ if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
100
+ return config_alt_diffusion
101
+
102
+ return config_default
103
+
104
+
105
+ def find_checkpoint_config(state_dict, info):
106
+ if info is None:
107
+ return guess_model_config_from_state_dict(state_dict, "")
108
+
109
+ config = find_checkpoint_config_near_filename(info)
110
+ if config is not None:
111
+ return config
112
+
113
+ return guess_model_config_from_state_dict(state_dict, info.filename)
114
+
115
+
116
+ def find_checkpoint_config_near_filename(info):
117
+ if info is None:
118
+ return None
119
+
120
+ config = f"{os.path.splitext(info.filename)[0]}.yaml"
121
+ if os.path.exists(config):
122
+ return config
123
+
124
+ return None
125
+
modules/sd_models_xl.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ import sgm.models.diffusion
6
+ import sgm.modules.diffusionmodules.denoiser_scaling
7
+ import sgm.modules.diffusionmodules.discretizer
8
+ from modules import devices, shared, prompt_parser
9
+
10
+
11
+ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
12
+ for embedder in self.conditioner.embedders:
13
+ embedder.ucg_rate = 0.0
14
+
15
+ width = getattr(batch, 'width', 1024)
16
+ height = getattr(batch, 'height', 1024)
17
+ is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
18
+ aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
19
+
20
+ devices_args = dict(device=devices.device, dtype=devices.dtype)
21
+
22
+ sdxl_conds = {
23
+ "txt": batch,
24
+ "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
25
+ "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
26
+ "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
27
+ "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
28
+ }
29
+
30
+ force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
31
+ c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
32
+
33
+ return c
34
+
35
+
36
+ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
37
+ return self.model(x, t, cond)
38
+
39
+
40
+ def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
41
+ return x
42
+
43
+
44
+ sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
45
+ sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
46
+ sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
47
+
48
+
49
+ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
50
+ res = []
51
+
52
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
53
+ encoded = embedder.encode_embedding_init_text(init_text, nvpt)
54
+ res.append(encoded)
55
+
56
+ return torch.cat(res, dim=1)
57
+
58
+
59
+ def process_texts(self, texts):
60
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
61
+ return embedder.process_texts(texts)
62
+
63
+
64
+ def get_target_prompt_token_count(self, token_count):
65
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
66
+ return embedder.get_target_prompt_token_count(token_count)
67
+
68
+
69
+ # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
70
+ sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
71
+ sgm.modules.GeneralConditioner.process_texts = process_texts
72
+ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
73
+
74
+
75
+ def extend_sdxl(model):
76
+ """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
77
+
78
+ dtype = next(model.model.diffusion_model.parameters()).dtype
79
+ model.model.diffusion_model.dtype = dtype
80
+ model.model.conditioning_key = 'crossattn'
81
+ model.cond_stage_key = 'txt'
82
+ # model.cond_stage_model will be set in sd_hijack
83
+
84
+ model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
85
+
86
+ discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
87
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
88
+
89
+ model.conditioner.wrapped = torch.nn.Module()
90
+
91
+
92
+ sgm.modules.attention.print = lambda *args: None
93
+ sgm.modules.diffusionmodules.model.print = lambda *args: None
94
+ sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
95
+ sgm.modules.encoders.modules.print = lambda *args: None
96
+
97
+ # this gets the code to load the vanilla attention that we override
98
+ sgm.modules.attention.SDP_IS_AVAILABLE = True
99
+ sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
modules/sd_samplers.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
2
+
3
+ # imports for functions that previously were here and are used by other modules
4
+ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
5
+
6
+ all_samplers = [
7
+ *sd_samplers_kdiffusion.samplers_data_k_diffusion,
8
+ *sd_samplers_compvis.samplers_data_compvis,
9
+ ]
10
+ all_samplers_map = {x.name: x for x in all_samplers}
11
+
12
+ samplers = []
13
+ samplers_for_img2img = []
14
+ samplers_map = {}
15
+
16
+
17
+ def find_sampler_config(name):
18
+ if name is not None:
19
+ config = all_samplers_map.get(name, None)
20
+ else:
21
+ config = all_samplers[0]
22
+
23
+ return config
24
+
25
+
26
+ def create_sampler(name, model):
27
+ config = find_sampler_config(name)
28
+
29
+ assert config is not None, f'bad sampler name: {name}'
30
+
31
+ if model.is_sdxl and config.options.get("no_sdxl", False):
32
+ raise Exception(f"Sampler {config.name} is not supported for SDXL")
33
+
34
+ sampler = config.constructor(model)
35
+ sampler.config = config
36
+
37
+ return sampler
38
+
39
+
40
+ def set_samplers():
41
+ global samplers, samplers_for_img2img
42
+
43
+ hidden = set(shared.opts.hide_samplers)
44
+ hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
45
+
46
+ samplers = [x for x in all_samplers if x.name not in hidden]
47
+ samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
48
+
49
+ samplers_map.clear()
50
+ for sampler in all_samplers:
51
+ samplers_map[sampler.name.lower()] = sampler.name
52
+ for alias in sampler.aliases:
53
+ samplers_map[alias.lower()] = sampler.name
54
+
55
+
56
+ set_samplers()
modules/sd_samplers_common.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
6
+
7
+ from modules.shared import opts, state
8
+ import modules.shared as shared
9
+
10
+ SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
11
+
12
+
13
+ def setup_img2img_steps(p, steps=None):
14
+ if opts.img2img_fix_steps or steps is not None:
15
+ requested_steps = (steps or p.steps)
16
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
17
+ t_enc = requested_steps - 1
18
+ else:
19
+ steps = p.steps
20
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
21
+
22
+ return steps, t_enc
23
+
24
+
25
+ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
26
+
27
+
28
+ def single_sample_to_image(sample, approximation=None):
29
+ if approximation is None:
30
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
31
+
32
+ if approximation == 2:
33
+ x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
34
+ elif approximation == 1:
35
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
36
+ elif approximation == 3:
37
+ x_sample = sample * 1.5
38
+ x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
39
+ else:
40
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
41
+
42
+ x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
43
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
44
+ x_sample = x_sample.astype(np.uint8)
45
+
46
+ return Image.fromarray(x_sample)
47
+
48
+
49
+ def sample_to_image(samples, index=0, approximation=None):
50
+ return single_sample_to_image(samples[index], approximation)
51
+
52
+
53
+ def samples_to_image_grid(samples, approximation=None):
54
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
55
+
56
+
57
+ def store_latent(decoded):
58
+ state.current_latent = decoded
59
+
60
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
61
+ if not shared.parallel_processing_allowed:
62
+ shared.state.assign_current_image(sample_to_image(decoded))
63
+
64
+
65
+ def is_sampler_using_eta_noise_seed_delta(p):
66
+ """returns whether sampler from config will use eta noise seed delta for image creation"""
67
+
68
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
69
+
70
+ eta = p.eta
71
+
72
+ if eta is None and p.sampler is not None:
73
+ eta = p.sampler.eta
74
+
75
+ if eta is None and sampler_config is not None:
76
+ eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
77
+
78
+ if eta == 0:
79
+ return False
80
+
81
+ return sampler_config.options.get("uses_ensd", False)
82
+
83
+
84
+ class InterruptedException(BaseException):
85
+ pass
86
+
87
+
88
+ if opts.randn_source == "CPU":
89
+ import torchsde._brownian.brownian_interval
90
+
91
+ def torchsde_randn(size, dtype, device, seed):
92
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
93
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
94
+
95
+ torchsde._brownian.brownian_interval._randn = torchsde_randn
modules/sd_samplers_compvis.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import ldm.models.diffusion.ddim
3
+ import ldm.models.diffusion.plms
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from modules.shared import state
9
+ from modules import sd_samplers_common, prompt_parser, shared
10
+ import modules.models.diffusion.uni_pc
11
+
12
+
13
+ samplers_data_compvis = [
14
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
15
+ sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
16
+ sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
17
+ ]
18
+
19
+
20
+ class VanillaStableDiffusionSampler:
21
+ def __init__(self, constructor, sd_model):
22
+ self.sampler = constructor(sd_model)
23
+ self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
24
+ self.is_plms = hasattr(self.sampler, 'p_sample_plms')
25
+ self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
26
+ self.orig_p_sample_ddim = None
27
+ if self.is_plms:
28
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms
29
+ elif self.is_ddim:
30
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim
31
+ self.mask = None
32
+ self.nmask = None
33
+ self.init_latent = None
34
+ self.sampler_noises = None
35
+ self.step = 0
36
+ self.stop_at = None
37
+ self.eta = None
38
+ self.config = None
39
+ self.last_latent = None
40
+
41
+ self.conditioning_key = sd_model.model.conditioning_key
42
+
43
+ def number_of_needed_noises(self, p):
44
+ return 0
45
+
46
+ def launch_sampling(self, steps, func):
47
+ state.sampling_steps = steps
48
+ state.sampling_step = 0
49
+
50
+ try:
51
+ return func()
52
+ except sd_samplers_common.InterruptedException:
53
+ return self.last_latent
54
+
55
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
56
+ x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
57
+
58
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
59
+
60
+ x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
61
+
62
+ return res
63
+
64
+ def before_sample(self, x, ts, cond, unconditional_conditioning):
65
+ if state.interrupted or state.skipped:
66
+ raise sd_samplers_common.InterruptedException
67
+
68
+ if self.stop_at is not None and self.step > self.stop_at:
69
+ raise sd_samplers_common.InterruptedException
70
+
71
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
72
+ image_conditioning = None
73
+ uc_image_conditioning = None
74
+ if isinstance(cond, dict):
75
+ if self.conditioning_key == "crossattn-adm":
76
+ image_conditioning = cond["c_adm"]
77
+ uc_image_conditioning = unconditional_conditioning["c_adm"]
78
+ else:
79
+ image_conditioning = cond["c_concat"][0]
80
+ cond = cond["c_crossattn"][0]
81
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
82
+
83
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
84
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
85
+
86
+ assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
87
+ cond = tensor
88
+
89
+ # for DDIM, shapes must match, we can't just process cond and uncond independently;
90
+ # filling unconditional_conditioning with repeats of the last vector to match length is
91
+ # not 100% correct but should work well enough
92
+ if unconditional_conditioning.shape[1] < cond.shape[1]:
93
+ last_vector = unconditional_conditioning[:, -1:]
94
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
95
+ unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
96
+ elif unconditional_conditioning.shape[1] > cond.shape[1]:
97
+ unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
98
+
99
+ if self.mask is not None:
100
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
101
+ x = img_orig * self.mask + self.nmask * x
102
+
103
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
104
+ # Note that they need to be lists because it just concatenates them later.
105
+ if image_conditioning is not None:
106
+ if self.conditioning_key == "crossattn-adm":
107
+ cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
108
+ unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
109
+ else:
110
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
111
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
112
+
113
+ return x, ts, cond, unconditional_conditioning
114
+
115
+ def update_step(self, last_latent):
116
+ if self.mask is not None:
117
+ self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
118
+ else:
119
+ self.last_latent = last_latent
120
+
121
+ sd_samplers_common.store_latent(self.last_latent)
122
+
123
+ self.step += 1
124
+ state.sampling_step = self.step
125
+ shared.total_tqdm.update()
126
+
127
+ def after_sample(self, x, ts, cond, uncond, res):
128
+ if not self.is_unipc:
129
+ self.update_step(res[1])
130
+
131
+ return x, ts, cond, uncond, res
132
+
133
+ def unipc_after_update(self, x, model_x):
134
+ self.update_step(x)
135
+
136
+ def initialize(self, p):
137
+ if self.is_ddim:
138
+ self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
139
+ else:
140
+ self.eta = 0.0
141
+
142
+ if self.eta != 0.0:
143
+ p.extra_generation_params["Eta DDIM"] = self.eta
144
+
145
+ if self.is_unipc:
146
+ keys = [
147
+ ('UniPC variant', 'uni_pc_variant'),
148
+ ('UniPC skip type', 'uni_pc_skip_type'),
149
+ ('UniPC order', 'uni_pc_order'),
150
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
151
+ ]
152
+
153
+ for name, key in keys:
154
+ v = getattr(shared.opts, key)
155
+ if v != shared.opts.get_default(key):
156
+ p.extra_generation_params[name] = v
157
+
158
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
159
+ if hasattr(self.sampler, fieldname):
160
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
161
+ if self.is_unipc:
162
+ self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
163
+
164
+ self.mask = p.mask if hasattr(p, 'mask') else None
165
+ self.nmask = p.nmask if hasattr(p, 'nmask') else None
166
+
167
+
168
+ def adjust_steps_if_invalid(self, p, num_steps):
169
+ if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
170
+ if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
171
+ num_steps = shared.opts.uni_pc_order
172
+ valid_step = 999 / (1000 // num_steps)
173
+ if valid_step == math.floor(valid_step):
174
+ return int(valid_step) + 1
175
+
176
+ return num_steps
177
+
178
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
179
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
180
+ steps = self.adjust_steps_if_invalid(p, steps)
181
+ self.initialize(p)
182
+
183
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
184
+ x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
185
+
186
+ self.init_latent = x
187
+ self.last_latent = x
188
+ self.step = 0
189
+
190
+ # Wrap the conditioning models with additional image conditioning for inpainting model
191
+ if image_conditioning is not None:
192
+ if self.conditioning_key == "crossattn-adm":
193
+ conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
194
+ unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
195
+ else:
196
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
197
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
198
+
199
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
200
+
201
+ return samples
202
+
203
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
204
+ self.initialize(p)
205
+
206
+ self.init_latent = None
207
+ self.last_latent = x
208
+ self.step = 0
209
+
210
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
211
+
212
+ # Wrap the conditioning models with additional image conditioning for inpainting model
213
+ # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
214
+ if image_conditioning is not None:
215
+ if self.conditioning_key == "crossattn-adm":
216
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
217
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
218
+ else:
219
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
220
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
221
+
222
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
223
+
224
+ return samples_ddim
modules/sd_samplers_kdiffusion.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import torch
3
+ import inspect
4
+ import k_diffusion.sampling
5
+ from modules import prompt_parser, devices, sd_samplers_common
6
+
7
+ from modules.shared import opts, state
8
+ import modules.shared as shared
9
+ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
10
+ from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
11
+ from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
12
+
13
+ samplers_k_diffusion = [
14
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
15
+ ('Euler', 'sample_euler', ['k_euler'], {}),
16
+ ('LMS', 'sample_lms', ['k_lms'], {}),
17
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
18
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
19
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
20
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
21
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
22
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
23
+ ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
24
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
25
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
26
+ ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
27
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
28
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
29
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
30
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
31
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
32
+ ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
33
+ ]
34
+
35
+ samplers_data_k_diffusion = [
36
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
37
+ for label, funcname, aliases, options in samplers_k_diffusion
38
+ if hasattr(k_diffusion.sampling, funcname)
39
+ ]
40
+
41
+ sampler_extra_params = {
42
+ 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
43
+ 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
44
+ 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
45
+ }
46
+
47
+ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
48
+ k_diffusion_scheduler = {
49
+ 'Automatic': None,
50
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
51
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
52
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
53
+ }
54
+
55
+
56
+ def catenate_conds(conds):
57
+ if not isinstance(conds[0], dict):
58
+ return torch.cat(conds)
59
+
60
+ return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
61
+
62
+
63
+ def subscript_cond(cond, a, b):
64
+ if not isinstance(cond, dict):
65
+ return cond[a:b]
66
+
67
+ return {key: vec[a:b] for key, vec in cond.items()}
68
+
69
+
70
+ def pad_cond(tensor, repeats, empty):
71
+ if not isinstance(tensor, dict):
72
+ return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
73
+
74
+ tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
75
+ return tensor
76
+
77
+
78
+ class CFGDenoiser(torch.nn.Module):
79
+ """
80
+ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
81
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
82
+ instead of one. Originally, the second prompt is just an empty string, but we use non-empty
83
+ negative prompt.
84
+ """
85
+
86
+ def __init__(self, model):
87
+ super().__init__()
88
+ self.inner_model = model
89
+ self.mask = None
90
+ self.nmask = None
91
+ self.init_latent = None
92
+ self.step = 0
93
+ self.image_cfg_scale = None
94
+ self.padded_cond_uncond = False
95
+
96
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
97
+ denoised_uncond = x_out[-uncond.shape[0]:]
98
+ denoised = torch.clone(denoised_uncond)
99
+
100
+ for i, conds in enumerate(conds_list):
101
+ for cond_index, weight in conds:
102
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
103
+
104
+ return denoised
105
+
106
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
107
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
108
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
109
+
110
+ return denoised
111
+
112
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
113
+ if state.interrupted or state.skipped:
114
+ raise sd_samplers_common.InterruptedException
115
+
116
+ # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
117
+ # so is_edit_model is set to False to support AND composition.
118
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
119
+
120
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
121
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
122
+
123
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
124
+
125
+ batch_size = len(conds_list)
126
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
127
+
128
+ if shared.sd_model.model.conditioning_key == "crossattn-adm":
129
+ image_uncond = torch.zeros_like(image_cond)
130
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
131
+ else:
132
+ image_uncond = image_cond
133
+ if isinstance(uncond, dict):
134
+ make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
135
+ else:
136
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
137
+
138
+ if not is_edit_model:
139
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
140
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
141
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
142
+ else:
143
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
144
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
145
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
146
+
147
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
148
+ cfg_denoiser_callback(denoiser_params)
149
+ x_in = denoiser_params.x
150
+ image_cond_in = denoiser_params.image_cond
151
+ sigma_in = denoiser_params.sigma
152
+ tensor = denoiser_params.text_cond
153
+ uncond = denoiser_params.text_uncond
154
+ skip_uncond = False
155
+
156
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
157
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
158
+ skip_uncond = True
159
+ x_in = x_in[:-batch_size]
160
+ sigma_in = sigma_in[:-batch_size]
161
+
162
+ self.padded_cond_uncond = False
163
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
164
+ empty = shared.sd_model.cond_stage_model_empty_prompt
165
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
166
+
167
+ if num_repeats < 0:
168
+ tensor = pad_cond(tensor, -num_repeats, empty)
169
+ self.padded_cond_uncond = True
170
+ elif num_repeats > 0:
171
+ uncond = pad_cond(uncond, num_repeats, empty)
172
+ self.padded_cond_uncond = True
173
+
174
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
175
+ if is_edit_model:
176
+ cond_in = catenate_conds([tensor, uncond, uncond])
177
+ elif skip_uncond:
178
+ cond_in = tensor
179
+ else:
180
+ cond_in = catenate_conds([tensor, uncond])
181
+
182
+ if shared.batch_cond_uncond:
183
+ x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
184
+ else:
185
+ x_out = torch.zeros_like(x_in)
186
+ for batch_offset in range(0, x_out.shape[0], batch_size):
187
+ a = batch_offset
188
+ b = a + batch_size
189
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
190
+ else:
191
+ x_out = torch.zeros_like(x_in)
192
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
193
+ for batch_offset in range(0, tensor.shape[0], batch_size):
194
+ a = batch_offset
195
+ b = min(a + batch_size, tensor.shape[0])
196
+
197
+ if not is_edit_model:
198
+ c_crossattn = subscript_cond(tensor, a, b)
199
+ else:
200
+ c_crossattn = torch.cat([tensor[a:b]], uncond)
201
+
202
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
203
+
204
+ if not skip_uncond:
205
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
206
+
207
+ denoised_image_indexes = [x[0][0] for x in conds_list]
208
+ if skip_uncond:
209
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
210
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
211
+
212
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
213
+ cfg_denoised_callback(denoised_params)
214
+
215
+ devices.test_for_nans(x_out, "unet")
216
+
217
+ if opts.live_preview_content == "Prompt":
218
+ sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
219
+ elif opts.live_preview_content == "Negative prompt":
220
+ sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
221
+
222
+ if is_edit_model:
223
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
224
+ elif skip_uncond:
225
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
226
+ else:
227
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
228
+
229
+ if self.mask is not None:
230
+ denoised = self.init_latent * self.mask + self.nmask * denoised
231
+
232
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
233
+ cfg_after_cfg_callback(after_cfg_callback_params)
234
+ denoised = after_cfg_callback_params.x
235
+
236
+ self.step += 1
237
+ return denoised
238
+
239
+
240
+ class TorchHijack:
241
+ def __init__(self, sampler_noises):
242
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
243
+ # implementation.
244
+ self.sampler_noises = deque(sampler_noises)
245
+
246
+ def __getattr__(self, item):
247
+ if item == 'randn_like':
248
+ return self.randn_like
249
+
250
+ if hasattr(torch, item):
251
+ return getattr(torch, item)
252
+
253
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
254
+
255
+ def randn_like(self, x):
256
+ if self.sampler_noises:
257
+ noise = self.sampler_noises.popleft()
258
+ if noise.shape == x.shape:
259
+ return noise
260
+
261
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
262
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
263
+ else:
264
+ return torch.randn_like(x)
265
+
266
+
267
+ class KDiffusionSampler:
268
+ def __init__(self, funcname, sd_model):
269
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
270
+
271
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
272
+ self.funcname = funcname
273
+ self.func = getattr(k_diffusion.sampling, self.funcname)
274
+ self.extra_params = sampler_extra_params.get(funcname, [])
275
+ self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
276
+ self.sampler_noises = None
277
+ self.stop_at = None
278
+ self.eta = None
279
+ self.config = None # set by the function calling the constructor
280
+ self.last_latent = None
281
+ self.s_min_uncond = None
282
+
283
+ self.conditioning_key = sd_model.model.conditioning_key
284
+
285
+ def callback_state(self, d):
286
+ step = d['i']
287
+ latent = d["denoised"]
288
+ if opts.live_preview_content == "Combined":
289
+ sd_samplers_common.store_latent(latent)
290
+ self.last_latent = latent
291
+
292
+ if self.stop_at is not None and step > self.stop_at:
293
+ raise sd_samplers_common.InterruptedException
294
+
295
+ state.sampling_step = step
296
+ shared.total_tqdm.update()
297
+
298
+ def launch_sampling(self, steps, func):
299
+ state.sampling_steps = steps
300
+ state.sampling_step = 0
301
+
302
+ try:
303
+ return func()
304
+ except RecursionError:
305
+ print(
306
+ 'Encountered RecursionError during sampling, returning last latent. '
307
+ 'rho >5 with a polyexponential scheduler may cause this error. '
308
+ 'You should try to use a smaller rho value instead.'
309
+ )
310
+ return self.last_latent
311
+ except sd_samplers_common.InterruptedException:
312
+ return self.last_latent
313
+
314
+ def number_of_needed_noises(self, p):
315
+ return p.steps
316
+
317
+ def initialize(self, p):
318
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
319
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
320
+ self.model_wrap_cfg.step = 0
321
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
322
+ self.eta = p.eta if p.eta is not None else opts.eta_ancestral
323
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
324
+
325
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
326
+
327
+ extra_params_kwargs = {}
328
+ for param_name in self.extra_params:
329
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
330
+ extra_params_kwargs[param_name] = getattr(p, param_name)
331
+
332
+ if 'eta' in inspect.signature(self.func).parameters:
333
+ if self.eta != 1.0:
334
+ p.extra_generation_params["Eta"] = self.eta
335
+
336
+ extra_params_kwargs['eta'] = self.eta
337
+
338
+ return extra_params_kwargs
339
+
340
+ def get_sigmas(self, p, steps):
341
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
342
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
343
+ discard_next_to_last_sigma = True
344
+ p.extra_generation_params["Discard penultimate sigma"] = True
345
+
346
+ steps += 1 if discard_next_to_last_sigma else 0
347
+
348
+ if p.sampler_noise_scheduler_override:
349
+ sigmas = p.sampler_noise_scheduler_override(steps)
350
+ elif opts.k_sched_type != "Automatic":
351
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
352
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
353
+ sigmas_kwargs = {
354
+ 'sigma_min': sigma_min,
355
+ 'sigma_max': sigma_max,
356
+ }
357
+
358
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
359
+ p.extra_generation_params["Schedule type"] = opts.k_sched_type
360
+
361
+ if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
362
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
363
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
364
+ if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
365
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
366
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
367
+
368
+ default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
369
+
370
+ if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
371
+ sigmas_kwargs['rho'] = opts.rho
372
+ p.extra_generation_params["Schedule rho"] = opts.rho
373
+
374
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
375
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
376
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
377
+
378
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
379
+ else:
380
+ sigmas = self.model_wrap.get_sigmas(steps)
381
+
382
+ if discard_next_to_last_sigma:
383
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
384
+
385
+ return sigmas
386
+
387
+ def create_noise_sampler(self, x, sigmas, p):
388
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
389
+ if shared.opts.no_dpmpp_sde_batch_determinism:
390
+ return None
391
+
392
+ from k_diffusion.sampling import BrownianTreeNoiseSampler
393
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
394
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
395
+ return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
396
+
397
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
398
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
399
+
400
+ sigmas = self.get_sigmas(p, steps)
401
+
402
+ sigma_sched = sigmas[steps - t_enc - 1:]
403
+ xi = x + noise * sigma_sched[0]
404
+
405
+ extra_params_kwargs = self.initialize(p)
406
+ parameters = inspect.signature(self.func).parameters
407
+
408
+ if 'sigma_min' in parameters:
409
+ ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
410
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
411
+ if 'sigma_max' in parameters:
412
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
413
+ if 'n' in parameters:
414
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
415
+ if 'sigma_sched' in parameters:
416
+ extra_params_kwargs['sigma_sched'] = sigma_sched
417
+ if 'sigmas' in parameters:
418
+ extra_params_kwargs['sigmas'] = sigma_sched
419
+
420
+ if self.config.options.get('brownian_noise', False):
421
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
422
+ extra_params_kwargs['noise_sampler'] = noise_sampler
423
+
424
+ self.model_wrap_cfg.init_latent = x
425
+ self.last_latent = x
426
+ extra_args = {
427
+ 'cond': conditioning,
428
+ 'image_cond': image_conditioning,
429
+ 'uncond': unconditional_conditioning,
430
+ 'cond_scale': p.cfg_scale,
431
+ 's_min_uncond': self.s_min_uncond
432
+ }
433
+
434
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
435
+
436
+ if self.model_wrap_cfg.padded_cond_uncond:
437
+ p.extra_generation_params["Pad conds"] = True
438
+
439
+ return samples
440
+
441
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
442
+ steps = steps or p.steps
443
+
444
+ sigmas = self.get_sigmas(p, steps)
445
+
446
+ x = x * sigmas[0]
447
+
448
+ extra_params_kwargs = self.initialize(p)
449
+ parameters = inspect.signature(self.func).parameters
450
+
451
+ if 'sigma_min' in parameters:
452
+ extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
453
+ extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
454
+ if 'n' in parameters:
455
+ extra_params_kwargs['n'] = steps
456
+ else:
457
+ extra_params_kwargs['sigmas'] = sigmas
458
+
459
+ if self.config.options.get('brownian_noise', False):
460
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
461
+ extra_params_kwargs['noise_sampler'] = noise_sampler
462
+
463
+ self.last_latent = x
464
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
465
+ 'cond': conditioning,
466
+ 'image_cond': image_conditioning,
467
+ 'uncond': unconditional_conditioning,
468
+ 'cond_scale': p.cfg_scale,
469
+ 's_min_uncond': self.s_min_uncond
470
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
471
+
472
+ if self.model_wrap_cfg.padded_cond_uncond:
473
+ p.extra_generation_params["Pad conds"] = True
474
+
475
+ return samples
476
+
modules/sd_unet.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn
2
+ import ldm.modules.diffusionmodules.openaimodel
3
+
4
+ from modules import script_callbacks, shared, devices
5
+
6
+ unet_options = []
7
+ current_unet_option = None
8
+ current_unet = None
9
+
10
+
11
+ def list_unets():
12
+ new_unets = script_callbacks.list_unets_callback()
13
+
14
+ unet_options.clear()
15
+ unet_options.extend(new_unets)
16
+
17
+
18
+ def get_unet_option(option=None):
19
+ option = option or shared.opts.sd_unet
20
+
21
+ if option == "None":
22
+ return None
23
+
24
+ if option == "Automatic":
25
+ name = shared.sd_model.sd_checkpoint_info.model_name
26
+
27
+ options = [x for x in unet_options if x.model_name == name]
28
+
29
+ option = options[0].label if options else "None"
30
+
31
+ return next(iter([x for x in unet_options if x.label == option]), None)
32
+
33
+
34
+ def apply_unet(option=None):
35
+ global current_unet_option
36
+ global current_unet
37
+
38
+ new_option = get_unet_option(option)
39
+ if new_option == current_unet_option:
40
+ return
41
+
42
+ if current_unet is not None:
43
+ print(f"Dectivating unet: {current_unet.option.label}")
44
+ current_unet.deactivate()
45
+
46
+ current_unet_option = new_option
47
+ if current_unet_option is None:
48
+ current_unet = None
49
+
50
+ if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
51
+ shared.sd_model.model.diffusion_model.to(devices.device)
52
+
53
+ return
54
+
55
+ shared.sd_model.model.diffusion_model.to(devices.cpu)
56
+ devices.torch_gc()
57
+
58
+ current_unet = current_unet_option.create_unet()
59
+ current_unet.option = current_unet_option
60
+ print(f"Activating unet: {current_unet.option.label}")
61
+ current_unet.activate()
62
+
63
+
64
+ class SdUnetOption:
65
+ model_name = None
66
+ """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
67
+
68
+ label = None
69
+ """name of the unet in UI"""
70
+
71
+ def create_unet(self):
72
+ """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
73
+ raise NotImplementedError()
74
+
75
+
76
+ class SdUnet(torch.nn.Module):
77
+ def forward(self, x, timesteps, context, *args, **kwargs):
78
+ raise NotImplementedError()
79
+
80
+ def activate(self):
81
+ pass
82
+
83
+ def deactivate(self):
84
+ pass
85
+
86
+
87
+ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
88
+ if current_unet is not None:
89
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
90
+
91
+ return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
92
+
modules/sd_vae.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import collections
3
+ from modules import paths, shared, devices, script_callbacks, sd_models
4
+ import glob
5
+ from copy import deepcopy
6
+
7
+
8
+ vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
9
+ vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
10
+ vae_dict = {}
11
+
12
+
13
+ base_vae = None
14
+ loaded_vae_file = None
15
+ checkpoint_info = None
16
+
17
+ checkpoints_loaded = collections.OrderedDict()
18
+
19
+ def get_base_vae(model):
20
+ if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
21
+ return base_vae
22
+ return None
23
+
24
+
25
+ def store_base_vae(model):
26
+ global base_vae, checkpoint_info
27
+ if checkpoint_info != model.sd_checkpoint_info:
28
+ assert not loaded_vae_file, "Trying to store non-base VAE!"
29
+ base_vae = deepcopy(model.first_stage_model.state_dict())
30
+ checkpoint_info = model.sd_checkpoint_info
31
+
32
+
33
+ def delete_base_vae():
34
+ global base_vae, checkpoint_info
35
+ base_vae = None
36
+ checkpoint_info = None
37
+
38
+
39
+ def restore_base_vae(model):
40
+ global loaded_vae_file
41
+ if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
42
+ print("Restoring base VAE")
43
+ _load_vae_dict(model, base_vae)
44
+ loaded_vae_file = None
45
+ delete_base_vae()
46
+
47
+
48
+ def get_filename(filepath):
49
+ return os.path.basename(filepath)
50
+
51
+
52
+ def refresh_vae_list():
53
+ vae_dict.clear()
54
+
55
+ paths = [
56
+ os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
57
+ os.path.join(sd_models.model_path, '**/*.vae.pt'),
58
+ os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
59
+ os.path.join(vae_path, '**/*.ckpt'),
60
+ os.path.join(vae_path, '**/*.pt'),
61
+ os.path.join(vae_path, '**/*.safetensors'),
62
+ ]
63
+
64
+ if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
65
+ paths += [
66
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
67
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
68
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
69
+ ]
70
+
71
+ if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
72
+ paths += [
73
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
74
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
75
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
76
+ ]
77
+
78
+ candidates = []
79
+ for path in paths:
80
+ candidates += glob.iglob(path, recursive=True)
81
+
82
+ for filepath in candidates:
83
+ name = get_filename(filepath)
84
+ vae_dict[name] = filepath
85
+
86
+
87
+ def find_vae_near_checkpoint(checkpoint_file):
88
+ checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
89
+ for vae_file in vae_dict.values():
90
+ if os.path.basename(vae_file).startswith(checkpoint_path):
91
+ return vae_file
92
+
93
+ return None
94
+
95
+
96
+ def resolve_vae(checkpoint_file):
97
+ if shared.cmd_opts.vae_path is not None:
98
+ return shared.cmd_opts.vae_path, 'from commandline argument'
99
+
100
+ is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
101
+
102
+ vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
103
+ if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
104
+ return vae_near_checkpoint, 'found near the checkpoint'
105
+
106
+ if shared.opts.sd_vae == "None":
107
+ return None, None
108
+
109
+ vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
110
+ if vae_from_options is not None:
111
+ return vae_from_options, 'specified in settings'
112
+
113
+ if not is_automatic:
114
+ print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
115
+
116
+ return None, None
117
+
118
+
119
+ def load_vae_dict(filename, map_location):
120
+ vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
121
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
122
+ return vae_dict_1
123
+
124
+
125
+ def load_vae(model, vae_file=None, vae_source="from unknown source"):
126
+ global vae_dict, loaded_vae_file
127
+ # save_settings = False
128
+
129
+ cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
130
+
131
+ if vae_file:
132
+ if cache_enabled and vae_file in checkpoints_loaded:
133
+ # use vae checkpoint cache
134
+ print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
135
+ store_base_vae(model)
136
+ _load_vae_dict(model, checkpoints_loaded[vae_file])
137
+ else:
138
+ assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
139
+ print(f"Loading VAE weights {vae_source}: {vae_file}")
140
+ store_base_vae(model)
141
+
142
+ vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
143
+ _load_vae_dict(model, vae_dict_1)
144
+
145
+ if cache_enabled:
146
+ # cache newly loaded vae
147
+ checkpoints_loaded[vae_file] = vae_dict_1.copy()
148
+
149
+ # clean up cache if limit is reached
150
+ if cache_enabled:
151
+ while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
152
+ checkpoints_loaded.popitem(last=False) # LRU
153
+
154
+ # If vae used is not in dict, update it
155
+ # It will be removed on refresh though
156
+ vae_opt = get_filename(vae_file)
157
+ if vae_opt not in vae_dict:
158
+ vae_dict[vae_opt] = vae_file
159
+
160
+ elif loaded_vae_file:
161
+ restore_base_vae(model)
162
+
163
+ loaded_vae_file = vae_file
164
+
165
+
166
+ # don't call this from outside
167
+ def _load_vae_dict(model, vae_dict_1):
168
+ model.first_stage_model.load_state_dict(vae_dict_1)
169
+ model.first_stage_model.to(devices.dtype_vae)
170
+
171
+
172
+ def clear_loaded_vae():
173
+ global loaded_vae_file
174
+ loaded_vae_file = None
175
+
176
+
177
+ unspecified = object()
178
+
179
+
180
+ def reload_vae_weights(sd_model=None, vae_file=unspecified):
181
+ from modules import lowvram, devices, sd_hijack
182
+
183
+ if not sd_model:
184
+ sd_model = shared.sd_model
185
+
186
+ checkpoint_info = sd_model.sd_checkpoint_info
187
+ checkpoint_file = checkpoint_info.filename
188
+
189
+ if vae_file == unspecified:
190
+ vae_file, vae_source = resolve_vae(checkpoint_file)
191
+ else:
192
+ vae_source = "from function argument"
193
+
194
+ if loaded_vae_file == vae_file:
195
+ return
196
+
197
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
198
+ lowvram.send_everything_to_cpu()
199
+ else:
200
+ sd_model.to(devices.cpu)
201
+
202
+ sd_hijack.model_hijack.undo_hijack(sd_model)
203
+
204
+ load_vae(sd_model, vae_file, vae_source)
205
+
206
+ sd_hijack.model_hijack.hijack(sd_model)
207
+ script_callbacks.model_loaded_callback(sd_model)
208
+
209
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
210
+ sd_model.to(devices.device)
211
+
212
+ print("VAE weights loaded.")
213
+ return sd_model
modules/sd_vae_approx.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from modules import devices, paths, shared
6
+
7
+ sd_vae_approx_models = {}
8
+
9
+
10
+ class VAEApprox(nn.Module):
11
+ def __init__(self):
12
+ super(VAEApprox, self).__init__()
13
+ self.conv1 = nn.Conv2d(4, 8, (7, 7))
14
+ self.conv2 = nn.Conv2d(8, 16, (5, 5))
15
+ self.conv3 = nn.Conv2d(16, 32, (3, 3))
16
+ self.conv4 = nn.Conv2d(32, 64, (3, 3))
17
+ self.conv5 = nn.Conv2d(64, 32, (3, 3))
18
+ self.conv6 = nn.Conv2d(32, 16, (3, 3))
19
+ self.conv7 = nn.Conv2d(16, 8, (3, 3))
20
+ self.conv8 = nn.Conv2d(8, 3, (3, 3))
21
+
22
+ def forward(self, x):
23
+ extra = 11
24
+ x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
25
+ x = nn.functional.pad(x, (extra, extra, extra, extra))
26
+
27
+ for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
28
+ x = layer(x)
29
+ x = nn.functional.leaky_relu(x, 0.1)
30
+
31
+ return x
32
+
33
+
34
+ def download_model(model_path, model_url):
35
+ if not os.path.exists(model_path):
36
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
37
+
38
+ print(f'Downloading VAEApprox model to: {model_path}')
39
+ torch.hub.download_url_to_file(model_url, model_path)
40
+
41
+
42
+ def model():
43
+ model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
44
+ loaded_model = sd_vae_approx_models.get(model_name)
45
+
46
+ if loaded_model is None:
47
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
48
+ if not os.path.exists(model_path):
49
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
50
+
51
+ if not os.path.exists(model_path):
52
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
53
+ download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
54
+
55
+ loaded_model = VAEApprox()
56
+ loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
57
+ loaded_model.eval()
58
+ loaded_model.to(devices.device, devices.dtype)
59
+ sd_vae_approx_models[model_name] = loaded_model
60
+
61
+ return loaded_model
62
+
63
+
64
+ def cheap_approximation(sample):
65
+ # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
66
+
67
+ if shared.sd_model.is_sdxl:
68
+ coeffs = [
69
+ [ 0.3448, 0.4168, 0.4395],
70
+ [-0.1953, -0.0290, 0.0250],
71
+ [ 0.1074, 0.0886, -0.0163],
72
+ [-0.3730, -0.2499, -0.2088],
73
+ ]
74
+ else:
75
+ coeffs = [
76
+ [ 0.298, 0.207, 0.208],
77
+ [ 0.187, 0.286, 0.173],
78
+ [-0.158, 0.189, 0.264],
79
+ [-0.184, -0.271, -0.473],
80
+ ]
81
+
82
+ coefs = torch.tensor(coeffs).to(sample.device)
83
+
84
+ x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
85
+
86
+ return x_sample
modules/sd_vae_taesd.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tiny AutoEncoder for Stable Diffusion
3
+ (DNN for encoding / decoding SD's latent space)
4
+
5
+ https://github.com/madebyollin/taesd
6
+ """
7
+ import os
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from modules import devices, paths_internal, shared
12
+
13
+ sd_vae_taesd_models = {}
14
+
15
+
16
+ def conv(n_in, n_out, **kwargs):
17
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
18
+
19
+
20
+ class Clamp(nn.Module):
21
+ @staticmethod
22
+ def forward(x):
23
+ return torch.tanh(x / 3) * 3
24
+
25
+
26
+ class Block(nn.Module):
27
+ def __init__(self, n_in, n_out):
28
+ super().__init__()
29
+ self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
30
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
31
+ self.fuse = nn.ReLU()
32
+
33
+ def forward(self, x):
34
+ return self.fuse(self.conv(x) + self.skip(x))
35
+
36
+
37
+ def decoder():
38
+ return nn.Sequential(
39
+ Clamp(), conv(4, 64), nn.ReLU(),
40
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
41
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
42
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
43
+ Block(64, 64), conv(64, 3),
44
+ )
45
+
46
+
47
+ class TAESD(nn.Module):
48
+ latent_magnitude = 3
49
+ latent_shift = 0.5
50
+
51
+ def __init__(self, decoder_path="taesd_decoder.pth"):
52
+ """Initialize pretrained TAESD on the given device from the given checkpoints."""
53
+ super().__init__()
54
+ self.decoder = decoder()
55
+ self.decoder.load_state_dict(
56
+ torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
57
+
58
+ @staticmethod
59
+ def unscale_latents(x):
60
+ """[0, 1] -> raw latents"""
61
+ return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
62
+
63
+
64
+ def download_model(model_path, model_url):
65
+ if not os.path.exists(model_path):
66
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
67
+
68
+ print(f'Downloading TAESD decoder to: {model_path}')
69
+ torch.hub.download_url_to_file(model_url, model_path)
70
+
71
+
72
+ def model():
73
+ model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
74
+ loaded_model = sd_vae_taesd_models.get(model_name)
75
+
76
+ if loaded_model is None:
77
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
78
+ download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
79
+
80
+ if os.path.exists(model_path):
81
+ loaded_model = TAESD(model_path)
82
+ loaded_model.eval()
83
+ loaded_model.to(devices.device, devices.dtype)
84
+ sd_vae_taesd_models[model_name] = loaded_model
85
+ else:
86
+ raise FileNotFoundError('TAESD model not found')
87
+
88
+ return loaded_model.decoder
modules/shared.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import json
3
+ import os
4
+ import re
5
+ import sys
6
+ import threading
7
+ import time
8
+ import logging
9
+
10
+ import gradio as gr
11
+ import torch
12
+ import tqdm
13
+
14
+ import launch
15
+ import modules.interrogate
16
+ import modules.memmon
17
+ import modules.styles
18
+ import modules.devices as devices
19
+ from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
20
+
21
+ from modules.generation_parameters_copypaste import infotext_to_setting_name_mapping
22
+ from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
23
+ from ldm.models.diffusion.ddpm import LatentDiffusion
24
+ from typing import Optional
25
+
26
+ log = logging.getLogger(__name__)
27
+
28
+ demo = None
29
+
30
+ parser = cmd_args.parser
31
+
32
+ script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
33
+ script_loading.preload_extensions(extensions_builtin_dir, parser)
34
+
35
+ if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
36
+ cmd_opts = parser.parse_args()
37
+ else:
38
+ cmd_opts, _ = parser.parse_known_args()
39
+
40
+
41
+ restricted_opts = {
42
+ "samples_filename_pattern",
43
+ "directories_filename_pattern",
44
+ "outdir_samples",
45
+ "outdir_txt2img_samples",
46
+ "outdir_img2img_samples",
47
+ "outdir_extras_samples",
48
+ "outdir_grids",
49
+ "outdir_txt2img_grids",
50
+ "outdir_save",
51
+ "outdir_init_images"
52
+ }
53
+
54
+ # https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
55
+ gradio_hf_hub_themes = [
56
+ "gradio/glass",
57
+ "gradio/monochrome",
58
+ "gradio/seafoam",
59
+ "gradio/soft",
60
+ "freddyaboulton/dracula_revamped",
61
+ "gradio/dracula_test",
62
+ "abidlabs/dracula_test",
63
+ "abidlabs/pakistan",
64
+ "dawood/microsoft_windows",
65
+ "ysharma/steampunk"
66
+ ]
67
+
68
+
69
+ cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
70
+
71
+ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
72
+ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
73
+
74
+ devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
75
+ devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
76
+
77
+ device = devices.device
78
+ weight_load_location = None if cmd_opts.lowram else "cpu"
79
+
80
+ batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
81
+ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
82
+ xformers_available = False
83
+ config_filename = cmd_opts.ui_settings_file
84
+
85
+ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
86
+ hypernetworks = {}
87
+ loaded_hypernetworks = []
88
+
89
+
90
+ def reload_hypernetworks():
91
+ from modules.hypernetworks import hypernetwork
92
+ global hypernetworks
93
+
94
+ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
95
+
96
+
97
+ class State:
98
+ skipped = False
99
+ interrupted = False
100
+ job = ""
101
+ job_no = 0
102
+ job_count = 0
103
+ processing_has_refined_job_count = False
104
+ job_timestamp = '0'
105
+ sampling_step = 0
106
+ sampling_steps = 0
107
+ current_latent = None
108
+ current_image = None
109
+ current_image_sampling_step = 0
110
+ id_live_preview = 0
111
+ textinfo = None
112
+ time_start = None
113
+ server_start = None
114
+ _server_command_signal = threading.Event()
115
+ _server_command: Optional[str] = None
116
+
117
+ @property
118
+ def need_restart(self) -> bool:
119
+ # Compatibility getter for need_restart.
120
+ return self.server_command == "restart"
121
+
122
+ @need_restart.setter
123
+ def need_restart(self, value: bool) -> None:
124
+ # Compatibility setter for need_restart.
125
+ if value:
126
+ self.server_command = "restart"
127
+
128
+ @property
129
+ def server_command(self):
130
+ return self._server_command
131
+
132
+ @server_command.setter
133
+ def server_command(self, value: Optional[str]) -> None:
134
+ """
135
+ Set the server command to `value` and signal that it's been set.
136
+ """
137
+ self._server_command = value
138
+ self._server_command_signal.set()
139
+
140
+ def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
141
+ """
142
+ Wait for server command to get set; return and clear the value and signal.
143
+ """
144
+ if self._server_command_signal.wait(timeout):
145
+ self._server_command_signal.clear()
146
+ req = self._server_command
147
+ self._server_command = None
148
+ return req
149
+ return None
150
+
151
+ def request_restart(self) -> None:
152
+ self.interrupt()
153
+ self.server_command = "restart"
154
+ log.info("Received restart request")
155
+
156
+ def skip(self):
157
+ self.skipped = True
158
+ log.info("Received skip request")
159
+
160
+ def interrupt(self):
161
+ self.interrupted = True
162
+ log.info("Received interrupt request")
163
+
164
+ def nextjob(self):
165
+ if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
166
+ self.do_set_current_image()
167
+
168
+ self.job_no += 1
169
+ self.sampling_step = 0
170
+ self.current_image_sampling_step = 0
171
+
172
+ def dict(self):
173
+ obj = {
174
+ "skipped": self.skipped,
175
+ "interrupted": self.interrupted,
176
+ "job": self.job,
177
+ "job_count": self.job_count,
178
+ "job_timestamp": self.job_timestamp,
179
+ "job_no": self.job_no,
180
+ "sampling_step": self.sampling_step,
181
+ "sampling_steps": self.sampling_steps,
182
+ }
183
+
184
+ return obj
185
+
186
+ def begin(self, job: str = "(unknown)"):
187
+ self.sampling_step = 0
188
+ self.job_count = -1
189
+ self.processing_has_refined_job_count = False
190
+ self.job_no = 0
191
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
192
+ self.current_latent = None
193
+ self.current_image = None
194
+ self.current_image_sampling_step = 0
195
+ self.id_live_preview = 0
196
+ self.skipped = False
197
+ self.interrupted = False
198
+ self.textinfo = None
199
+ self.time_start = time.time()
200
+ self.job = job
201
+ devices.torch_gc()
202
+ log.info("Starting job %s", job)
203
+
204
+ def end(self):
205
+ duration = time.time() - self.time_start
206
+ log.info("Ending job %s (%.2f seconds)", self.job, duration)
207
+ self.job = ""
208
+ self.job_count = 0
209
+
210
+ devices.torch_gc()
211
+
212
+ def set_current_image(self):
213
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
214
+ if not parallel_processing_allowed:
215
+ return
216
+
217
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
218
+ self.do_set_current_image()
219
+
220
+ def do_set_current_image(self):
221
+ if self.current_latent is None:
222
+ return
223
+
224
+ import modules.sd_samplers
225
+ if opts.show_progress_grid:
226
+ self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
227
+ else:
228
+ self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
229
+
230
+ self.current_image_sampling_step = self.sampling_step
231
+
232
+ def assign_current_image(self, image):
233
+ self.current_image = image
234
+ self.id_live_preview += 1
235
+
236
+
237
+ state = State()
238
+ state.server_start = time.time()
239
+
240
+ styles_filename = cmd_opts.styles_file
241
+ prompt_styles = modules.styles.StyleDatabase(styles_filename)
242
+
243
+ interrogator = modules.interrogate.InterrogateModels("interrogate")
244
+
245
+ face_restorers = []
246
+
247
+
248
+ class OptionInfo:
249
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
250
+ self.default = default
251
+ self.label = label
252
+ self.component = component
253
+ self.component_args = component_args
254
+ self.onchange = onchange
255
+ self.section = section
256
+ self.refresh = refresh
257
+
258
+ self.comment_before = comment_before
259
+ """HTML text that will be added after label in UI"""
260
+
261
+ self.comment_after = comment_after
262
+ """HTML text that will be added before label in UI"""
263
+
264
+ def link(self, label, url):
265
+ self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
266
+ return self
267
+
268
+ def js(self, label, js_func):
269
+ self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
270
+ return self
271
+
272
+ def info(self, info):
273
+ self.comment_after += f"<span class='info'>({info})</span>"
274
+ return self
275
+
276
+ def html(self, html):
277
+ self.comment_after += html
278
+ return self
279
+
280
+ def needs_restart(self):
281
+ self.comment_after += " <span class='info'>(requires restart)</span>"
282
+ return self
283
+
284
+
285
+
286
+
287
+ def options_section(section_identifier, options_dict):
288
+ for v in options_dict.values():
289
+ v.section = section_identifier
290
+
291
+ return options_dict
292
+
293
+
294
+ def list_checkpoint_tiles():
295
+ import modules.sd_models
296
+ return modules.sd_models.checkpoint_tiles()
297
+
298
+
299
+ def refresh_checkpoints():
300
+ import modules.sd_models
301
+ return modules.sd_models.list_models()
302
+
303
+
304
+ def list_samplers():
305
+ import modules.sd_samplers
306
+ return modules.sd_samplers.all_samplers
307
+
308
+
309
+
310
+ hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
311
+ tab_names = []
312
+
313
+ options_templates = {}
314
+
315
+ options_templates.update(options_section(('saving-images', "Saving images/grids"), {
316
+ "samples_save": OptionInfo(True, "Always save all generated images"),
317
+ "samples_format": OptionInfo('png', 'File format for images'),
318
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
319
+ "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
320
+
321
+ "grid_save": OptionInfo(True, "Always save all generated image grids"),
322
+ "grid_format": OptionInfo('png', 'File format for grids'),
323
+ "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
324
+ "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
325
+ "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
326
+ "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
327
+ "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
328
+ "font": OptionInfo("", "Font for image grids that have text"),
329
+ "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
330
+ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
331
+ "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
332
+
333
+ "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
334
+ "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
335
+ "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
336
+ "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
337
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
338
+ "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
339
+ "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
340
+ "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
341
+ "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
342
+ "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
343
+ "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
344
+ "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
345
+ "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
346
+
347
+ "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
348
+ "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
349
+ "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
350
+ "save_init_img": OptionInfo(False, "Save init images when using img2img"),
351
+
352
+ "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
353
+ "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
354
+
355
+ }))
356
+
357
+ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
358
+ "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
359
+ "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
360
+ "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
361
+ "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
362
+ "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
363
+ "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
364
+ "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
365
+ "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
366
+ "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
367
+ }))
368
+
369
+ options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
370
+ "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
371
+ "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
372
+ "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
373
+ "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
374
+ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
375
+ }))
376
+
377
+ options_templates.update(options_section(('upscaling', "Upscaling"), {
378
+ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
379
+ "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
380
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
381
+ "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
382
+ }))
383
+
384
+ options_templates.update(options_section(('face-restoration', "Face restoration"), {
385
+ "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
386
+ "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
387
+ "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
388
+ }))
389
+
390
+ options_templates.update(options_section(('system', "System"), {
391
+ "show_warnings": OptionInfo(False, "Show warnings in console."),
392
+ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
393
+ "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
394
+ "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
395
+ "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
396
+ "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
397
+ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
398
+ }))
399
+
400
+ options_templates.update(options_section(('training', "Training"), {
401
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
402
+ "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
403
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
404
+ "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
405
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
406
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
407
+ "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
408
+ "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
409
+ "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
410
+ "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
411
+ "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
412
+ "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
413
+ }))
414
+
415
+ options_templates.update(options_section(('sd', "Stable Diffusion"), {
416
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
417
+ "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
418
+ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
419
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
420
+ "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
421
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
422
+ "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
423
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
424
+ "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
425
+ "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
426
+ "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
427
+ "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
428
+ "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
429
+ "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
430
+ "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
431
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
432
+ "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
433
+ "sd_max_resolution": OptionInfo(2048, "Max resolution output for txt2img and img2img"),
434
+ "ignore_overrides": OptionInfo([], "Ignore Overrides", gr.CheckboxGroup, lambda: {"choices": [x[0] for x in infotext_to_setting_name_mapping]}),
435
+ "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
436
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
437
+ }))
438
+
439
+ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
440
+ "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
441
+ "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
442
+ "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
443
+ "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
444
+ }))
445
+
446
+ options_templates.update(options_section(('optimizations', "Optimizations"), {
447
+ "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
448
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
449
+ "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
450
+ "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
451
+ "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
452
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
453
+ "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
454
+ }))
455
+
456
+ options_templates.update(options_section(('compatibility', "Compatibility"), {
457
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
458
+ "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
459
+ "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
460
+ "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
461
+ "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
462
+ "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
463
+ }))
464
+
465
+ options_templates.update(options_section(('interrogate', "Interrogate Options"), {
466
+ "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
467
+ "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
468
+ "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
469
+ "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
470
+ "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
471
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
472
+ "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
473
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
474
+ "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
475
+ "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
476
+ "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
477
+ "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
478
+ }))
479
+
480
+ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
481
+ "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
482
+ "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
483
+ #"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
484
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
485
+ #"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
486
+ #"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
487
+ "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
488
+ "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
489
+
490
+ "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
491
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
492
+ "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
493
+ "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
494
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
495
+ "extra_networks_default_visibility": OptionInfo(True, "Extra Networks default visibility"),
496
+ "extra_networks_cards_size": OptionInfo(1, "Card size for extra networks", gr.Slider, {"minimum": 0.8, "maximum": 2, "step": 0.1}),
497
+ "extra_networks_cards_visible_rows": OptionInfo(1, "Visible card rows for extra networks", gr.Slider, {"minimum": 1, "maximum": 3, "step": 1}),
498
+ "extra_networks_aside": OptionInfo(True, "Extra Networks aside view"),
499
+ }))
500
+
501
+ options_templates.update(options_section(('ui', "User interface"), {
502
+ "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
503
+ "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
504
+ "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
505
+ "return_grid": OptionInfo(True, "Show grid in results for web"),
506
+ "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
507
+ "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
508
+ "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
509
+ "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
510
+ "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
511
+ "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
512
+ "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
513
+ "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
514
+ "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
515
+ "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
516
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
517
+ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
518
+ "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
519
+ "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
520
+ "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
521
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
522
+ "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
523
+ "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
524
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
525
+ "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
526
+ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
527
+ "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
528
+ "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings"),
529
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
530
+ "ui_hidden_tabs": OptionInfo("", "Hidden Tabs"),
531
+ "ui_header_tabs": OptionInfo("", "Header Tabs"),
532
+ "ui_views_order": OptionInfo("row-reverse", "Interface order input/parameters | output/preview", gr.Radio, {"choices": ["row", "row-reverse"]}),
533
+ "ui_output_image_fit": OptionInfo("Scale-down", "Generated image fit method", gr.Radio, {"choices": ["Scale-down", "Contain"]}),
534
+ "ui_show_range_ticks": OptionInfo(True, "Show ticks for range sliders"),
535
+ "ui_dispatch_input_release": OptionInfo(True, "Dispatch event change on release, for slider and input number components"),
536
+ "ui_no_slider_layout": OptionInfo(False, "No sliders compact layout mode"),
537
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
538
+ }))
539
+
540
+ options_templates.update(options_section(('infotext', "Infotext"), {
541
+ "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
542
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
543
+ "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
544
+ "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
545
+ "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
546
+ "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
547
+ <li>Ignore: keep prompt and styles dropdown as it is.</li>
548
+ <li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
549
+ <li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
550
+ <li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
551
+ </ul>"""),
552
+
553
+ }))
554
+
555
+ options_templates.update(options_section(('ui', "Live previews"), {
556
+ "show_progressbar": OptionInfo(True, "Show progressbar"),
557
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
558
+ "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
559
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
560
+ "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
561
+ "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
562
+ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
563
+ "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
564
+ "live_preview_image_fit": OptionInfo("Scale-down", "Live preview image fit method", gr.Radio, {"choices": ["Scale-down", "Contain"]}),
565
+ }))
566
+
567
+ options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
568
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
569
+ "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
570
+ "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
571
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
572
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
573
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
574
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
575
+ 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
576
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
577
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
578
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
579
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
580
+ 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
581
+ 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
582
+ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
583
+ 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
584
+ 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
585
+ }))
586
+
587
+ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
588
+ 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
589
+ 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
590
+ 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
591
+ }))
592
+
593
+ options_templates.update(options_section((None, "Hidden options"), {
594
+ "disabled_extensions": OptionInfo([], "Disable these extensions"),
595
+ "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
596
+ "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
597
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
598
+ }))
599
+
600
+
601
+ options_templates.update()
602
+
603
+
604
+ class Options:
605
+ data = None
606
+ data_labels = options_templates
607
+ typemap = {int: float}
608
+
609
+ def __init__(self):
610
+ self.data = {k: v.default for k, v in self.data_labels.items()}
611
+
612
+ def __setattr__(self, key, value):
613
+ if self.data is not None:
614
+ if key in self.data or key in self.data_labels:
615
+ assert not cmd_opts.freeze_settings, "changing settings is disabled"
616
+
617
+ info = opts.data_labels.get(key, None)
618
+ comp_args = info.component_args if info else None
619
+ if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
620
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
621
+
622
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
623
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
624
+
625
+ self.data[key] = value
626
+ return
627
+
628
+ return super(Options, self).__setattr__(key, value)
629
+
630
+ def __getattr__(self, item):
631
+ if self.data is not None:
632
+ if item in self.data:
633
+ return self.data[item]
634
+
635
+ if item in self.data_labels:
636
+ return self.data_labels[item].default
637
+
638
+ return super(Options, self).__getattribute__(item)
639
+
640
+ def set(self, key, value):
641
+ """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
642
+
643
+ oldval = self.data.get(key, None)
644
+ if oldval == value:
645
+ return False
646
+
647
+ try:
648
+ setattr(self, key, value)
649
+ except RuntimeError:
650
+ return False
651
+
652
+ if self.data_labels[key].onchange is not None:
653
+ try:
654
+ self.data_labels[key].onchange()
655
+ except Exception as e:
656
+ errors.display(e, f"changing setting {key} to {value}")
657
+ setattr(self, key, oldval)
658
+ return False
659
+
660
+ return True
661
+
662
+ def get_default(self, key):
663
+ """returns the default value for the key"""
664
+
665
+ data_label = self.data_labels.get(key)
666
+ if data_label is None:
667
+ return None
668
+
669
+ return data_label.default
670
+
671
+ def save(self, filename):
672
+ assert not cmd_opts.freeze_settings, "saving settings is disabled"
673
+
674
+ with open(filename, "w", encoding="utf8") as file:
675
+ json.dump(self.data, file, indent=4)
676
+
677
+ def same_type(self, x, y):
678
+ if x is None or y is None:
679
+ return True
680
+
681
+ type_x = self.typemap.get(type(x), type(x))
682
+ type_y = self.typemap.get(type(y), type(y))
683
+
684
+ return type_x == type_y
685
+
686
+ def load(self, filename):
687
+ with open(filename, "r", encoding="utf8") as file:
688
+ self.data = json.load(file)
689
+
690
+ # 1.1.1 quicksettings list migration
691
+ if self.data.get('quicksettings') is not None:
692
+ self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
693
+
694
+ # 1.4.0 ui_reorder
695
+ if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
696
+ self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
697
+
698
+ bad_settings = 0
699
+ for k, v in self.data.items():
700
+ info = self.data_labels.get(k, None)
701
+ if info is not None and not self.same_type(info.default, v):
702
+ print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
703
+ bad_settings += 1
704
+
705
+ if bad_settings > 0:
706
+ print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
707
+
708
+ def onchange(self, key, func, call=True):
709
+ item = self.data_labels.get(key)
710
+ item.onchange = func
711
+
712
+ if call:
713
+ func()
714
+
715
+ def dumpjson(self):
716
+ d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
717
+ d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
718
+ d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
719
+ return json.dumps(d)
720
+
721
+ def add_option(self, key, info):
722
+ self.data_labels[key] = info
723
+
724
+ def reorder(self):
725
+ """reorder settings so that all items related to section always go together"""
726
+
727
+ section_ids = {}
728
+ settings_items = self.data_labels.items()
729
+ for _, item in settings_items:
730
+ if item.section not in section_ids:
731
+ section_ids[item.section] = len(section_ids)
732
+
733
+ self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
734
+
735
+ def cast_value(self, key, value):
736
+ """casts an arbitrary to the same type as this setting's value with key
737
+ Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
738
+ """
739
+
740
+ if value is None:
741
+ return None
742
+
743
+ default_value = self.data_labels[key].default
744
+ if default_value is None:
745
+ default_value = getattr(self, key, None)
746
+ if default_value is None:
747
+ return None
748
+
749
+ expected_type = type(default_value)
750
+ if expected_type == bool and value == "False":
751
+ value = False
752
+ else:
753
+ value = expected_type(value)
754
+
755
+ return value
756
+
757
+
758
+ opts = Options()
759
+ if os.path.exists(config_filename):
760
+ opts.load(config_filename)
761
+
762
+
763
+ class Shared(sys.modules[__name__].__class__):
764
+ """
765
+ this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
766
+ at program startup.
767
+ """
768
+
769
+ sd_model_val = None
770
+
771
+ @property
772
+ def sd_model(self):
773
+ import modules.sd_models
774
+
775
+ return modules.sd_models.model_data.get_sd_model()
776
+
777
+ @sd_model.setter
778
+ def sd_model(self, value):
779
+ import modules.sd_models
780
+
781
+ modules.sd_models.model_data.set_sd_model(value)
782
+
783
+
784
+ sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
785
+ sys.modules[__name__].__class__ = Shared
786
+
787
+ settings_components = None
788
+ """assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
789
+
790
+ latent_upscale_default_mode = "Latent"
791
+ latent_upscale_modes = {
792
+ "Latent": {"mode": "bilinear", "antialias": False},
793
+ "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
794
+ "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
795
+ "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
796
+ "Latent (nearest)": {"mode": "nearest", "antialias": False},
797
+ "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
798
+ }
799
+
800
+ sd_upscalers = []
801
+
802
+ clip_model = None
803
+
804
+ progress_print_out = sys.stdout
805
+
806
+ gradio_theme = gr.themes.Base()
807
+
808
+
809
+ def reload_gradio_theme(theme_name=None):
810
+ global gradio_theme
811
+ if not theme_name:
812
+ theme_name = opts.gradio_theme
813
+
814
+ default_theme_args = dict(
815
+ font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
816
+ font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
817
+ )
818
+
819
+ if theme_name == "Default":
820
+ gradio_theme = gr.themes.Default(**default_theme_args)
821
+ else:
822
+ try:
823
+ gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
824
+ except Exception as e:
825
+ errors.display(e, "changing gradio theme")
826
+ gradio_theme = gr.themes.Default(**default_theme_args)
827
+
828
+
829
+
830
+ class TotalTQDM:
831
+ def __init__(self):
832
+ self._tqdm = None
833
+
834
+ def reset(self):
835
+ self._tqdm = tqdm.tqdm(
836
+ desc="Total progress",
837
+ total=state.job_count * state.sampling_steps,
838
+ position=1,
839
+ file=progress_print_out
840
+ )
841
+
842
+ def update(self):
843
+ if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
844
+ return
845
+ if self._tqdm is None:
846
+ self.reset()
847
+ self._tqdm.update()
848
+
849
+ def updateTotal(self, new_total):
850
+ if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
851
+ return
852
+ if self._tqdm is None:
853
+ self.reset()
854
+ self._tqdm.total = new_total
855
+
856
+ def clear(self):
857
+ if self._tqdm is not None:
858
+ self._tqdm.refresh()
859
+ self._tqdm.close()
860
+ self._tqdm = None
861
+
862
+
863
+ total_tqdm = TotalTQDM()
864
+
865
+ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
866
+ mem_mon.start()
867
+
868
+
869
+ def natural_sort_key(s, regex=re.compile('([0-9]+)')):
870
+ return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
871
+
872
+
873
+ def listfiles(dirname):
874
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
875
+ return [file for file in filenames if os.path.isfile(file)]
876
+
877
+
878
+ def html_path(filename):
879
+ return os.path.join(script_path, "html", filename)
880
+
881
+
882
+ def html(filename):
883
+ path = html_path(filename)
884
+
885
+ if os.path.exists(path):
886
+ with open(path, encoding="utf8") as file:
887
+ return file.read()
888
+
889
+ return ""
890
+
891
+
892
+ def walk_files(path, allowed_extensions=None):
893
+ if not os.path.exists(path):
894
+ return
895
+
896
+ if allowed_extensions is not None:
897
+ allowed_extensions = set(allowed_extensions)
898
+
899
+ items = list(os.walk(path, followlinks=True))
900
+ items = sorted(items, key=lambda x: natural_sort_key(x[0]))
901
+
902
+ for root, _, files in items:
903
+ for filename in sorted(files, key=natural_sort_key):
904
+ if allowed_extensions is not None:
905
+ _, ext = os.path.splitext(filename)
906
+ if ext not in allowed_extensions:
907
+ continue
908
+
909
+ if not opts.list_hidden_files and ("/." in root or "\\." in root):
910
+ continue
911
+
912
+ yield os.path.join(root, filename)
modules/shared_items.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ def realesrgan_models_names():
4
+ import modules.realesrgan_model
5
+ return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
6
+
7
+
8
+ def postprocessing_scripts():
9
+ import modules.scripts
10
+
11
+ return modules.scripts.scripts_postproc.scripts
12
+
13
+
14
+ def sd_vae_items():
15
+ import modules.sd_vae
16
+
17
+ return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
18
+
19
+
20
+ def refresh_vae_list():
21
+ import modules.sd_vae
22
+
23
+ modules.sd_vae.refresh_vae_list()
24
+
25
+
26
+ def cross_attention_optimizations():
27
+ import modules.sd_hijack
28
+
29
+ return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
30
+
31
+
32
+ def sd_unet_items():
33
+ import modules.sd_unet
34
+
35
+ return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
36
+
37
+
38
+ def refresh_unet_list():
39
+ import modules.sd_unet
40
+
41
+ modules.sd_unet.list_unets()
42
+
43
+
44
+ ui_reorder_categories_builtin_items = [
45
+ "inpaint",
46
+ "sampler",
47
+ "checkboxes",
48
+ "hires_fix",
49
+ "dimensions",
50
+ "cfg",
51
+ "seed",
52
+ "batch",
53
+ "override_settings",
54
+ ]
55
+
56
+
57
+ def ui_reorder_categories():
58
+ from modules import scripts
59
+
60
+ yield from ui_reorder_categories_builtin_items
61
+
62
+ sections = {}
63
+ for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
64
+ if isinstance(script.section, str):
65
+ sections[script.section] = 1
66
+
67
+ yield from sections
68
+
69
+ yield "scripts"
modules/styles.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import os.path
4
+ import re
5
+ import typing
6
+ import shutil
7
+
8
+
9
+ class PromptStyle(typing.NamedTuple):
10
+ name: str
11
+ prompt: str
12
+ negative_prompt: str
13
+
14
+
15
+ def merge_prompts(style_prompt: str, prompt: str) -> str:
16
+ if "{prompt}" in style_prompt:
17
+ res = style_prompt.replace("{prompt}", prompt)
18
+ else:
19
+ parts = filter(None, (prompt.strip(), style_prompt.strip()))
20
+ res = ", ".join(parts)
21
+
22
+ return res
23
+
24
+
25
+ def apply_styles_to_prompt(prompt, styles):
26
+ for style in styles:
27
+ prompt = merge_prompts(style, prompt)
28
+
29
+ return prompt
30
+
31
+
32
+ re_spaces = re.compile(" +")
33
+
34
+
35
+ def extract_style_text_from_prompt(style_text, prompt):
36
+ stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
37
+ stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
38
+ if "{prompt}" in stripped_style_text:
39
+ left, right = stripped_style_text.split("{prompt}", 2)
40
+ if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
41
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
42
+ return True, prompt
43
+ else:
44
+ if stripped_prompt.endswith(stripped_style_text):
45
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
46
+
47
+ if prompt.endswith(', '):
48
+ prompt = prompt[:-2]
49
+
50
+ return True, prompt
51
+
52
+ return False, prompt
53
+
54
+
55
+ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
56
+ if not style.prompt and not style.negative_prompt:
57
+ return False, prompt, negative_prompt
58
+
59
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
60
+ if not match_positive:
61
+ return False, prompt, negative_prompt
62
+
63
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
64
+ if not match_negative:
65
+ return False, prompt, negative_prompt
66
+
67
+ return True, extracted_positive, extracted_negative
68
+
69
+
70
+ class StyleDatabase:
71
+ def __init__(self, path: str):
72
+ self.no_style = PromptStyle("None", "", "")
73
+ self.styles = {}
74
+ self.path = path
75
+
76
+ self.reload()
77
+
78
+ def reload(self):
79
+ self.styles.clear()
80
+
81
+ if not os.path.exists(self.path):
82
+ return
83
+
84
+ with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
85
+ reader = csv.DictReader(file, skipinitialspace=True)
86
+ for row in reader:
87
+ # Support loading old CSV format with "name, text"-columns
88
+ prompt = row["prompt"] if "prompt" in row else row["text"]
89
+ negative_prompt = row.get("negative_prompt", "")
90
+ self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
91
+
92
+ def get_style_prompts(self, styles):
93
+ return [self.styles.get(x, self.no_style).prompt for x in styles]
94
+
95
+ def get_negative_style_prompts(self, styles):
96
+ return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
97
+
98
+ def apply_styles_to_prompt(self, prompt, styles):
99
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
100
+
101
+ def apply_negative_styles_to_prompt(self, prompt, styles):
102
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
103
+
104
+ def save_styles(self, path: str) -> None:
105
+ # Always keep a backup file around
106
+ if os.path.exists(path):
107
+ shutil.copy(path, f"{path}.bak")
108
+
109
+ fd = os.open(path, os.O_RDWR | os.O_CREAT)
110
+ with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
111
+ # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
112
+ # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
113
+ writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
114
+ writer.writeheader()
115
+ writer.writerows(style._asdict() for k, style in self.styles.items())
116
+
117
+ def extract_styles_from_prompt(self, prompt, negative_prompt):
118
+ extracted = []
119
+
120
+ applicable_styles = list(self.styles.values())
121
+
122
+ while True:
123
+ found_style = None
124
+
125
+ for style in applicable_styles:
126
+ is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
127
+ if is_match:
128
+ found_style = style
129
+ prompt = new_prompt
130
+ negative_prompt = new_neg_prompt
131
+ break
132
+
133
+ if not found_style:
134
+ break
135
+
136
+ applicable_styles.remove(found_style)
137
+ extracted.append(found_style.name)
138
+
139
+ return list(reversed(extracted)), prompt, negative_prompt
modules/sub_quadratic_attention.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original source:
2
+ # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
3
+ # license:
4
+ # MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
5
+ # credit:
6
+ # Amin Rezaei (original author)
7
+ # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
8
+ # brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
9
+ # implementation of:
10
+ # Self-attention Does Not Need O(n2) Memory":
11
+ # https://arxiv.org/abs/2112.05682v2
12
+
13
+ from functools import partial
14
+ import torch
15
+ from torch import Tensor
16
+ from torch.utils.checkpoint import checkpoint
17
+ import math
18
+ from typing import Optional, NamedTuple, List
19
+
20
+
21
+ def narrow_trunc(
22
+ input: Tensor,
23
+ dim: int,
24
+ start: int,
25
+ length: int
26
+ ) -> Tensor:
27
+ return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
28
+
29
+
30
+ class AttnChunk(NamedTuple):
31
+ exp_values: Tensor
32
+ exp_weights_sum: Tensor
33
+ max_score: Tensor
34
+
35
+
36
+ class SummarizeChunk:
37
+ @staticmethod
38
+ def __call__(
39
+ query: Tensor,
40
+ key: Tensor,
41
+ value: Tensor,
42
+ ) -> AttnChunk: ...
43
+
44
+
45
+ class ComputeQueryChunkAttn:
46
+ @staticmethod
47
+ def __call__(
48
+ query: Tensor,
49
+ key: Tensor,
50
+ value: Tensor,
51
+ ) -> Tensor: ...
52
+
53
+
54
+ def _summarize_chunk(
55
+ query: Tensor,
56
+ key: Tensor,
57
+ value: Tensor,
58
+ scale: float,
59
+ ) -> AttnChunk:
60
+ attn_weights = torch.baddbmm(
61
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
62
+ query,
63
+ key.transpose(1,2),
64
+ alpha=scale,
65
+ beta=0,
66
+ )
67
+ max_score, _ = torch.max(attn_weights, -1, keepdim=True)
68
+ max_score = max_score.detach()
69
+ exp_weights = torch.exp(attn_weights - max_score)
70
+ exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
71
+ max_score = max_score.squeeze(-1)
72
+ return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
73
+
74
+
75
+ def _query_chunk_attention(
76
+ query: Tensor,
77
+ key: Tensor,
78
+ value: Tensor,
79
+ summarize_chunk: SummarizeChunk,
80
+ kv_chunk_size: int,
81
+ ) -> Tensor:
82
+ batch_x_heads, k_tokens, k_channels_per_head = key.shape
83
+ _, _, v_channels_per_head = value.shape
84
+
85
+ def chunk_scanner(chunk_idx: int) -> AttnChunk:
86
+ key_chunk = narrow_trunc(
87
+ key,
88
+ 1,
89
+ chunk_idx,
90
+ kv_chunk_size
91
+ )
92
+ value_chunk = narrow_trunc(
93
+ value,
94
+ 1,
95
+ chunk_idx,
96
+ kv_chunk_size
97
+ )
98
+ return summarize_chunk(query, key_chunk, value_chunk)
99
+
100
+ chunks: List[AttnChunk] = [
101
+ chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
102
+ ]
103
+ acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
104
+ chunk_values, chunk_weights, chunk_max = acc_chunk
105
+
106
+ global_max, _ = torch.max(chunk_max, 0, keepdim=True)
107
+ max_diffs = torch.exp(chunk_max - global_max)
108
+ chunk_values *= torch.unsqueeze(max_diffs, -1)
109
+ chunk_weights *= max_diffs
110
+
111
+ all_values = chunk_values.sum(dim=0)
112
+ all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
113
+ return all_values / all_weights
114
+
115
+
116
+ # TODO: refactor CrossAttention#get_attention_scores to share code with this
117
+ def _get_attention_scores_no_kv_chunking(
118
+ query: Tensor,
119
+ key: Tensor,
120
+ value: Tensor,
121
+ scale: float,
122
+ ) -> Tensor:
123
+ attn_scores = torch.baddbmm(
124
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
125
+ query,
126
+ key.transpose(1,2),
127
+ alpha=scale,
128
+ beta=0,
129
+ )
130
+ attn_probs = attn_scores.softmax(dim=-1)
131
+ del attn_scores
132
+ hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
133
+ return hidden_states_slice
134
+
135
+
136
+ class ScannedChunk(NamedTuple):
137
+ chunk_idx: int
138
+ attn_chunk: AttnChunk
139
+
140
+
141
+ def efficient_dot_product_attention(
142
+ query: Tensor,
143
+ key: Tensor,
144
+ value: Tensor,
145
+ query_chunk_size=1024,
146
+ kv_chunk_size: Optional[int] = None,
147
+ kv_chunk_size_min: Optional[int] = None,
148
+ use_checkpoint=True,
149
+ ):
150
+ """Computes efficient dot-product attention given query, key, and value.
151
+ This is efficient version of attention presented in
152
+ https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
153
+ Args:
154
+ query: queries for calculating attention with shape of
155
+ `[batch * num_heads, tokens, channels_per_head]`.
156
+ key: keys for calculating attention with shape of
157
+ `[batch * num_heads, tokens, channels_per_head]`.
158
+ value: values to be used in attention with shape of
159
+ `[batch * num_heads, tokens, channels_per_head]`.
160
+ query_chunk_size: int: query chunks size
161
+ kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
162
+ kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
163
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
164
+ Returns:
165
+ Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
166
+ """
167
+ batch_x_heads, q_tokens, q_channels_per_head = query.shape
168
+ _, k_tokens, _ = key.shape
169
+ scale = q_channels_per_head ** -0.5
170
+
171
+ kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
172
+ if kv_chunk_size_min is not None:
173
+ kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
174
+
175
+ def get_query_chunk(chunk_idx: int) -> Tensor:
176
+ return narrow_trunc(
177
+ query,
178
+ 1,
179
+ chunk_idx,
180
+ min(query_chunk_size, q_tokens)
181
+ )
182
+
183
+ summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
184
+ summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
185
+ compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
186
+ _get_attention_scores_no_kv_chunking,
187
+ scale=scale
188
+ ) if k_tokens <= kv_chunk_size else (
189
+ # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
190
+ partial(
191
+ _query_chunk_attention,
192
+ kv_chunk_size=kv_chunk_size,
193
+ summarize_chunk=summarize_chunk,
194
+ )
195
+ )
196
+
197
+ if q_tokens <= query_chunk_size:
198
+ # fast-path for when there's just 1 query chunk
199
+ return compute_query_chunk_attn(
200
+ query=query,
201
+ key=key,
202
+ value=value,
203
+ )
204
+
205
+ res = torch.zeros_like(query)
206
+ for i in range(math.ceil(q_tokens / query_chunk_size)):
207
+ attn_scores = compute_query_chunk_attn(
208
+ query=get_query_chunk(i * query_chunk_size),
209
+ key=key,
210
+ value=value,
211
+ )
212
+
213
+ res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
214
+
215
+ return res
modules/sysinfo.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import traceback
5
+
6
+ import platform
7
+ import hashlib
8
+ import pkg_resources
9
+ import psutil
10
+ import re
11
+
12
+ import launch
13
+ from modules import paths_internal, timer
14
+
15
+ checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
16
+ environment_whitelist = {
17
+ "GIT",
18
+ "INDEX_URL",
19
+ "WEBUI_LAUNCH_LIVE_OUTPUT",
20
+ "GRADIO_ANALYTICS_ENABLED",
21
+ "PYTHONPATH",
22
+ "TORCH_INDEX_URL",
23
+ "TORCH_COMMAND",
24
+ "REQS_FILE",
25
+ "XFORMERS_PACKAGE",
26
+ "GFPGAN_PACKAGE",
27
+ "CLIP_PACKAGE",
28
+ "OPENCLIP_PACKAGE",
29
+ "STABLE_DIFFUSION_REPO",
30
+ "K_DIFFUSION_REPO",
31
+ "CODEFORMER_REPO",
32
+ "BLIP_REPO",
33
+ "STABLE_DIFFUSION_COMMIT_HASH",
34
+ "K_DIFFUSION_COMMIT_HASH",
35
+ "CODEFORMER_COMMIT_HASH",
36
+ "BLIP_COMMIT_HASH",
37
+ "COMMANDLINE_ARGS",
38
+ "IGNORE_CMD_ARGS_ERRORS",
39
+ }
40
+
41
+
42
+ def pretty_bytes(num, suffix="B"):
43
+ for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]:
44
+ if abs(num) < 1024 or unit == 'Y':
45
+ return f"{num:.0f}{unit}{suffix}"
46
+ num /= 1024
47
+
48
+
49
+ def get():
50
+ res = get_dict()
51
+
52
+ text = json.dumps(res, ensure_ascii=False, indent=4)
53
+
54
+ h = hashlib.sha256(text.encode("utf8"))
55
+ text = text.replace(checksum_token, h.hexdigest())
56
+
57
+ return text
58
+
59
+
60
+ re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"')
61
+
62
+
63
+ def check(x):
64
+ m = re.search(re_checksum, x)
65
+ if not m:
66
+ return False
67
+
68
+ replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x)
69
+
70
+ h = hashlib.sha256(replaced.encode("utf8"))
71
+ return h.hexdigest() == m.group(1)
72
+
73
+
74
+ def get_dict():
75
+ ram = psutil.virtual_memory()
76
+
77
+ res = {
78
+ "Platform": platform.platform(),
79
+ "Python": platform.python_version(),
80
+ "Version": launch.git_tag(),
81
+ "Commit": launch.commit_hash(),
82
+ "Script path": paths_internal.script_path,
83
+ "Data path": paths_internal.data_path,
84
+ "Extensions dir": paths_internal.extensions_dir,
85
+ "Checksum": checksum_token,
86
+ "Commandline": sys.argv,
87
+ "Torch env info": get_torch_sysinfo(),
88
+ "Exceptions": get_exceptions(),
89
+ "CPU": {
90
+ "model": platform.processor(),
91
+ "count logical": psutil.cpu_count(logical=True),
92
+ "count physical": psutil.cpu_count(logical=False),
93
+ },
94
+ "RAM": {
95
+ x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
96
+ },
97
+ "Extensions": get_extensions(enabled=True),
98
+ "Inactive extensions": get_extensions(enabled=False),
99
+ "Environment": get_environment(),
100
+ "Config": get_config(),
101
+ "Startup": timer.startup_record,
102
+ "Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
103
+ }
104
+
105
+ return res
106
+
107
+
108
+ def format_traceback(tb):
109
+ return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
110
+
111
+
112
+ def get_exceptions():
113
+ try:
114
+ from modules import errors
115
+
116
+ return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
117
+ except Exception as e:
118
+ return str(e)
119
+
120
+
121
+ def get_environment():
122
+ return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
123
+
124
+
125
+ re_newline = re.compile(r"\r*\n")
126
+
127
+
128
+ def get_torch_sysinfo():
129
+ try:
130
+ import torch.utils.collect_env
131
+ info = torch.utils.collect_env.get_env_info()._asdict()
132
+
133
+ return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()}
134
+ except Exception as e:
135
+ return str(e)
136
+
137
+
138
+ def get_extensions(*, enabled):
139
+
140
+ try:
141
+ from modules import extensions
142
+
143
+ def to_json(x: extensions.Extension):
144
+ return {
145
+ "name": x.name,
146
+ "path": x.path,
147
+ "version": x.version,
148
+ "branch": x.branch,
149
+ "remote": x.remote,
150
+ }
151
+
152
+ return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
153
+ except Exception as e:
154
+ return str(e)
155
+
156
+
157
+ def get_config():
158
+ try:
159
+ from modules import shared
160
+ return shared.opts.data
161
+ except Exception as e:
162
+ return str(e)
modules/textual_inversion/__pycache__/autocrop.cpython-310.pyc ADDED
Binary file (8.73 kB). View file
 
modules/textual_inversion/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (9.71 kB). View file
 
modules/textual_inversion/__pycache__/image_embedding.cpython-310.pyc ADDED
Binary file (7.87 kB). View file
 
modules/textual_inversion/__pycache__/learn_schedule.cpython-310.pyc ADDED
Binary file (2.74 kB). View file
 
modules/textual_inversion/__pycache__/logging.cpython-310.pyc ADDED
Binary file (1.68 kB). View file
 
modules/textual_inversion/__pycache__/preprocess.cpython-310.pyc ADDED
Binary file (7.17 kB). View file
 
modules/textual_inversion/__pycache__/textual_inversion.cpython-310.pyc ADDED
Binary file (20.7 kB). View file
 
modules/textual_inversion/__pycache__/ui.cpython-310.pyc ADDED
Binary file (1.61 kB). View file
 
modules/textual_inversion/autocrop.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import requests
3
+ import os
4
+ import numpy as np
5
+ from PIL import ImageDraw
6
+
7
+ GREEN = "#0F0"
8
+ BLUE = "#00F"
9
+ RED = "#F00"
10
+
11
+
12
+ def crop_image(im, settings):
13
+ """ Intelligently crop an image to the subject matter """
14
+
15
+ scale_by = 1
16
+ if is_landscape(im.width, im.height):
17
+ scale_by = settings.crop_height / im.height
18
+ elif is_portrait(im.width, im.height):
19
+ scale_by = settings.crop_width / im.width
20
+ elif is_square(im.width, im.height):
21
+ if is_square(settings.crop_width, settings.crop_height):
22
+ scale_by = settings.crop_width / im.width
23
+ elif is_landscape(settings.crop_width, settings.crop_height):
24
+ scale_by = settings.crop_width / im.width
25
+ elif is_portrait(settings.crop_width, settings.crop_height):
26
+ scale_by = settings.crop_height / im.height
27
+
28
+
29
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
30
+ im_debug = im.copy()
31
+
32
+ focus = focal_point(im_debug, settings)
33
+
34
+ # take the focal point and turn it into crop coordinates that try to center over the focal
35
+ # point but then get adjusted back into the frame
36
+ y_half = int(settings.crop_height / 2)
37
+ x_half = int(settings.crop_width / 2)
38
+
39
+ x1 = focus.x - x_half
40
+ if x1 < 0:
41
+ x1 = 0
42
+ elif x1 + settings.crop_width > im.width:
43
+ x1 = im.width - settings.crop_width
44
+
45
+ y1 = focus.y - y_half
46
+ if y1 < 0:
47
+ y1 = 0
48
+ elif y1 + settings.crop_height > im.height:
49
+ y1 = im.height - settings.crop_height
50
+
51
+ x2 = x1 + settings.crop_width
52
+ y2 = y1 + settings.crop_height
53
+
54
+ crop = [x1, y1, x2, y2]
55
+
56
+ results = []
57
+
58
+ results.append(im.crop(tuple(crop)))
59
+
60
+ if settings.annotate_image:
61
+ d = ImageDraw.Draw(im_debug)
62
+ rect = list(crop)
63
+ rect[2] -= 1
64
+ rect[3] -= 1
65
+ d.rectangle(rect, outline=GREEN)
66
+ results.append(im_debug)
67
+ if settings.destop_view_image:
68
+ im_debug.show()
69
+
70
+ return results
71
+
72
+ def focal_point(im, settings):
73
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
74
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
75
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
76
+
77
+ pois = []
78
+
79
+ weight_pref_total = 0
80
+ if corner_points:
81
+ weight_pref_total += settings.corner_points_weight
82
+ if entropy_points:
83
+ weight_pref_total += settings.entropy_points_weight
84
+ if face_points:
85
+ weight_pref_total += settings.face_points_weight
86
+
87
+ corner_centroid = None
88
+ if corner_points:
89
+ corner_centroid = centroid(corner_points)
90
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
91
+ pois.append(corner_centroid)
92
+
93
+ entropy_centroid = None
94
+ if entropy_points:
95
+ entropy_centroid = centroid(entropy_points)
96
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
97
+ pois.append(entropy_centroid)
98
+
99
+ face_centroid = None
100
+ if face_points:
101
+ face_centroid = centroid(face_points)
102
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
103
+ pois.append(face_centroid)
104
+
105
+ average_point = poi_average(pois, settings)
106
+
107
+ if settings.annotate_image:
108
+ d = ImageDraw.Draw(im)
109
+ max_size = min(im.width, im.height) * 0.07
110
+ if corner_centroid is not None:
111
+ color = BLUE
112
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
113
+ d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
114
+ d.ellipse(box, outline=color)
115
+ if len(corner_points) > 1:
116
+ for f in corner_points:
117
+ d.rectangle(f.bounding(4), outline=color)
118
+ if entropy_centroid is not None:
119
+ color = "#ff0"
120
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
121
+ d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
122
+ d.ellipse(box, outline=color)
123
+ if len(entropy_points) > 1:
124
+ for f in entropy_points:
125
+ d.rectangle(f.bounding(4), outline=color)
126
+ if face_centroid is not None:
127
+ color = RED
128
+ box = face_centroid.bounding(max_size * face_centroid.weight)
129
+ d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
130
+ d.ellipse(box, outline=color)
131
+ if len(face_points) > 1:
132
+ for f in face_points:
133
+ d.rectangle(f.bounding(4), outline=color)
134
+
135
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
136
+
137
+ return average_point
138
+
139
+
140
+ def image_face_points(im, settings):
141
+ if settings.dnn_model_path is not None:
142
+ detector = cv2.FaceDetectorYN.create(
143
+ settings.dnn_model_path,
144
+ "",
145
+ (im.width, im.height),
146
+ 0.9, # score threshold
147
+ 0.3, # nms threshold
148
+ 5000 # keep top k before nms
149
+ )
150
+ faces = detector.detect(np.array(im))
151
+ results = []
152
+ if faces[1] is not None:
153
+ for face in faces[1]:
154
+ x = face[0]
155
+ y = face[1]
156
+ w = face[2]
157
+ h = face[3]
158
+ results.append(
159
+ PointOfInterest(
160
+ int(x + (w * 0.5)), # face focus left/right is center
161
+ int(y + (h * 0.33)), # face focus up/down is close to the top of the head
162
+ size = w,
163
+ weight = 1/len(faces[1])
164
+ )
165
+ )
166
+ return results
167
+ else:
168
+ np_im = np.array(im)
169
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
170
+
171
+ tries = [
172
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
173
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
174
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
175
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
176
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
177
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
178
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
179
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
180
+ ]
181
+ for t in tries:
182
+ classifier = cv2.CascadeClassifier(t[0])
183
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
184
+ try:
185
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
186
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
187
+ except Exception:
188
+ continue
189
+
190
+ if faces:
191
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
192
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
193
+ return []
194
+
195
+
196
+ def image_corner_points(im, settings):
197
+ grayscale = im.convert("L")
198
+
199
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
200
+ gd = ImageDraw.Draw(grayscale)
201
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
202
+
203
+ np_im = np.array(grayscale)
204
+
205
+ points = cv2.goodFeaturesToTrack(
206
+ np_im,
207
+ maxCorners=100,
208
+ qualityLevel=0.04,
209
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
210
+ useHarrisDetector=False,
211
+ )
212
+
213
+ if points is None:
214
+ return []
215
+
216
+ focal_points = []
217
+ for point in points:
218
+ x, y = point.ravel()
219
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
220
+
221
+ return focal_points
222
+
223
+
224
+ def image_entropy_points(im, settings):
225
+ landscape = im.height < im.width
226
+ portrait = im.height > im.width
227
+ if landscape:
228
+ move_idx = [0, 2]
229
+ move_max = im.size[0]
230
+ elif portrait:
231
+ move_idx = [1, 3]
232
+ move_max = im.size[1]
233
+ else:
234
+ return []
235
+
236
+ e_max = 0
237
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
238
+ crop_best = crop_current
239
+ while crop_current[move_idx[1]] < move_max:
240
+ crop = im.crop(tuple(crop_current))
241
+ e = image_entropy(crop)
242
+
243
+ if (e > e_max):
244
+ e_max = e
245
+ crop_best = list(crop_current)
246
+
247
+ crop_current[move_idx[0]] += 4
248
+ crop_current[move_idx[1]] += 4
249
+
250
+ x_mid = int(crop_best[0] + settings.crop_width/2)
251
+ y_mid = int(crop_best[1] + settings.crop_height/2)
252
+
253
+ return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
254
+
255
+
256
+ def image_entropy(im):
257
+ # greyscale image entropy
258
+ # band = np.asarray(im.convert("L"))
259
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
260
+ hist, _ = np.histogram(band, bins=range(0, 256))
261
+ hist = hist[hist > 0]
262
+ return -np.log2(hist / hist.sum()).sum()
263
+
264
+
265
+ def centroid(pois):
266
+ x = [poi.x for poi in pois]
267
+ y = [poi.y for poi in pois]
268
+ return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
269
+
270
+
271
+ def poi_average(pois, settings):
272
+ weight = 0.0
273
+ x = 0.0
274
+ y = 0.0
275
+ for poi in pois:
276
+ weight += poi.weight
277
+ x += poi.x * poi.weight
278
+ y += poi.y * poi.weight
279
+ avg_x = round(weight and x / weight)
280
+ avg_y = round(weight and y / weight)
281
+
282
+ return PointOfInterest(avg_x, avg_y)
283
+
284
+
285
+ def is_landscape(w, h):
286
+ return w > h
287
+
288
+
289
+ def is_portrait(w, h):
290
+ return h > w
291
+
292
+
293
+ def is_square(w, h):
294
+ return w == h
295
+
296
+
297
+ def download_and_cache_models(dirname):
298
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
299
+ model_file_name = 'face_detection_yunet.onnx'
300
+
301
+ os.makedirs(dirname, exist_ok=True)
302
+
303
+ cache_file = os.path.join(dirname, model_file_name)
304
+ if not os.path.exists(cache_file):
305
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
306
+ response = requests.get(download_url)
307
+ with open(cache_file, "wb") as f:
308
+ f.write(response.content)
309
+
310
+ if os.path.exists(cache_file):
311
+ return cache_file
312
+ return None
313
+
314
+
315
+ class PointOfInterest:
316
+ def __init__(self, x, y, weight=1.0, size=10):
317
+ self.x = x
318
+ self.y = y
319
+ self.weight = weight
320
+ self.size = size
321
+
322
+ def bounding(self, size):
323
+ return [
324
+ self.x - size // 2,
325
+ self.y - size // 2,
326
+ self.x + size // 2,
327
+ self.y + size // 2
328
+ ]
329
+
330
+
331
+ class Settings:
332
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
333
+ self.crop_width = crop_width
334
+ self.crop_height = crop_height
335
+ self.corner_points_weight = corner_points_weight
336
+ self.entropy_points_weight = entropy_points_weight
337
+ self.face_points_weight = face_points_weight
338
+ self.annotate_image = annotate_image
339
+ self.destop_view_image = False
340
+ self.dnn_model_path = dnn_model_path
modules/textual_inversion/dataset.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ import torch
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset, DataLoader, Sampler
7
+ from torchvision import transforms
8
+ from collections import defaultdict
9
+ from random import shuffle, choices
10
+
11
+ import random
12
+ import tqdm
13
+ from modules import devices, shared
14
+ import re
15
+
16
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
17
+
18
+ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
19
+
20
+
21
+ class DatasetEntry:
22
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
23
+ self.filename = filename
24
+ self.filename_text = filename_text
25
+ self.weight = weight
26
+ self.latent_dist = latent_dist
27
+ self.latent_sample = latent_sample
28
+ self.cond = cond
29
+ self.cond_text = cond_text
30
+ self.pixel_values = pixel_values
31
+
32
+
33
+ class PersonalizedBase(Dataset):
34
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
35
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
36
+
37
+ self.placeholder_token = placeholder_token
38
+
39
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
40
+
41
+ self.dataset = []
42
+
43
+ with open(template_file, "r") as file:
44
+ lines = [x.strip() for x in file.readlines()]
45
+
46
+ self.lines = lines
47
+
48
+ assert data_root, 'dataset directory not specified'
49
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
50
+ assert os.listdir(data_root), "Dataset directory is empty"
51
+
52
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
53
+
54
+ self.shuffle_tags = shuffle_tags
55
+ self.tag_drop_out = tag_drop_out
56
+ groups = defaultdict(list)
57
+
58
+ print("Preparing dataset...")
59
+ for path in tqdm.tqdm(self.image_paths):
60
+ alpha_channel = None
61
+ if shared.state.interrupted:
62
+ raise Exception("interrupted")
63
+ try:
64
+ image = Image.open(path)
65
+ #Currently does not work for single color transparency
66
+ #We would need to read image.info['transparency'] for that
67
+ if use_weight and 'A' in image.getbands():
68
+ alpha_channel = image.getchannel('A')
69
+ image = image.convert('RGB')
70
+ if not varsize:
71
+ image = image.resize((width, height), PIL.Image.BICUBIC)
72
+ except Exception:
73
+ continue
74
+
75
+ text_filename = f"{os.path.splitext(path)[0]}.txt"
76
+ filename = os.path.basename(path)
77
+
78
+ if os.path.exists(text_filename):
79
+ with open(text_filename, "r", encoding="utf8") as file:
80
+ filename_text = file.read()
81
+ else:
82
+ filename_text = os.path.splitext(filename)[0]
83
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
84
+ if re_word:
85
+ tokens = re_word.findall(filename_text)
86
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
87
+
88
+ npimage = np.array(image).astype(np.uint8)
89
+ npimage = (npimage / 127.5 - 1.0).astype(np.float32)
90
+
91
+ torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
92
+ latent_sample = None
93
+
94
+ with devices.autocast():
95
+ latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
96
+
97
+ #Perform latent sampling, even for random sampling.
98
+ #We need the sample dimensions for the weights
99
+ if latent_sampling_method == "deterministic":
100
+ if isinstance(latent_dist, DiagonalGaussianDistribution):
101
+ # Works only for DiagonalGaussianDistribution
102
+ latent_dist.std = 0
103
+ else:
104
+ latent_sampling_method = "once"
105
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
106
+
107
+ if use_weight and alpha_channel is not None:
108
+ channels, *latent_size = latent_sample.shape
109
+ weight_img = alpha_channel.resize(latent_size)
110
+ npweight = np.array(weight_img).astype(np.float32)
111
+ #Repeat for every channel in the latent sample
112
+ weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
113
+ #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
114
+ weight -= weight.min()
115
+ weight /= weight.mean()
116
+ elif use_weight:
117
+ #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
118
+ weight = torch.ones(latent_sample.shape)
119
+ else:
120
+ weight = None
121
+
122
+ if latent_sampling_method == "random":
123
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
124
+ else:
125
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
126
+
127
+ if not (self.tag_drop_out != 0 or self.shuffle_tags):
128
+ entry.cond_text = self.create_text(filename_text)
129
+
130
+ if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
131
+ with devices.autocast():
132
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
133
+ groups[image.size].append(len(self.dataset))
134
+ self.dataset.append(entry)
135
+ del torchdata
136
+ del latent_dist
137
+ del latent_sample
138
+ del weight
139
+
140
+ self.length = len(self.dataset)
141
+ self.groups = list(groups.values())
142
+ assert self.length > 0, "No images have been found in the dataset."
143
+ self.batch_size = min(batch_size, self.length)
144
+ self.gradient_step = min(gradient_step, self.length // self.batch_size)
145
+ self.latent_sampling_method = latent_sampling_method
146
+
147
+ if len(groups) > 1:
148
+ print("Buckets:")
149
+ for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
150
+ print(f" {w}x{h}: {len(ids)}")
151
+ print()
152
+
153
+ def create_text(self, filename_text):
154
+ text = random.choice(self.lines)
155
+ tags = filename_text.split(',')
156
+ if self.tag_drop_out != 0:
157
+ tags = [t for t in tags if random.random() > self.tag_drop_out]
158
+ if self.shuffle_tags:
159
+ random.shuffle(tags)
160
+ text = text.replace("[filewords]", ','.join(tags))
161
+ text = text.replace("[name]", self.placeholder_token)
162
+ return text
163
+
164
+ def __len__(self):
165
+ return self.length
166
+
167
+ def __getitem__(self, i):
168
+ entry = self.dataset[i]
169
+ if self.tag_drop_out != 0 or self.shuffle_tags:
170
+ entry.cond_text = self.create_text(entry.filename_text)
171
+ if self.latent_sampling_method == "random":
172
+ entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
173
+ return entry
174
+
175
+
176
+ class GroupedBatchSampler(Sampler):
177
+ def __init__(self, data_source: PersonalizedBase, batch_size: int):
178
+ super().__init__(data_source)
179
+
180
+ n = len(data_source)
181
+ self.groups = data_source.groups
182
+ self.len = n_batch = n // batch_size
183
+ expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
184
+ self.base = [int(e) // batch_size for e in expected]
185
+ self.n_rand_batches = nrb = n_batch - sum(self.base)
186
+ self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
187
+ self.batch_size = batch_size
188
+
189
+ def __len__(self):
190
+ return self.len
191
+
192
+ def __iter__(self):
193
+ b = self.batch_size
194
+
195
+ for g in self.groups:
196
+ shuffle(g)
197
+
198
+ batches = []
199
+ for g in self.groups:
200
+ batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
201
+ for _ in range(self.n_rand_batches):
202
+ rand_group = choices(self.groups, self.probs)[0]
203
+ batches.append(choices(rand_group, k=b))
204
+
205
+ shuffle(batches)
206
+
207
+ yield from batches
208
+
209
+
210
+ class PersonalizedDataLoader(DataLoader):
211
+ def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
212
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
213
+ if latent_sampling_method == "random":
214
+ self.collate_fn = collate_wrapper_random
215
+ else:
216
+ self.collate_fn = collate_wrapper
217
+
218
+
219
+ class BatchLoader:
220
+ def __init__(self, data):
221
+ self.cond_text = [entry.cond_text for entry in data]
222
+ self.cond = [entry.cond for entry in data]
223
+ self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
224
+ if all(entry.weight is not None for entry in data):
225
+ self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
226
+ else:
227
+ self.weight = None
228
+ #self.emb_index = [entry.emb_index for entry in data]
229
+ #print(self.latent_sample.device)
230
+
231
+ def pin_memory(self):
232
+ self.latent_sample = self.latent_sample.pin_memory()
233
+ return self
234
+
235
+ def collate_wrapper(batch):
236
+ return BatchLoader(batch)
237
+
238
+ class BatchLoaderRandom(BatchLoader):
239
+ def __init__(self, data):
240
+ super().__init__(data)
241
+
242
+ def pin_memory(self):
243
+ return self
244
+
245
+ def collate_wrapper_random(batch):
246
+ return BatchLoaderRandom(batch)
modules/textual_inversion/image_embedding.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import warnings
4
+
5
+ import numpy as np
6
+ import zlib
7
+ from PIL import Image, ImageDraw
8
+ import torch
9
+
10
+
11
+ class EmbeddingEncoder(json.JSONEncoder):
12
+ def default(self, obj):
13
+ if isinstance(obj, torch.Tensor):
14
+ return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
15
+ return json.JSONEncoder.default(self, obj)
16
+
17
+
18
+ class EmbeddingDecoder(json.JSONDecoder):
19
+ def __init__(self, *args, **kwargs):
20
+ json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
21
+
22
+ def object_hook(self, d):
23
+ if 'TORCHTENSOR' in d:
24
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
25
+ return d
26
+
27
+
28
+ def embedding_to_b64(data):
29
+ d = json.dumps(data, cls=EmbeddingEncoder)
30
+ return base64.b64encode(d.encode())
31
+
32
+
33
+ def embedding_from_b64(data):
34
+ d = base64.b64decode(data)
35
+ return json.loads(d, cls=EmbeddingDecoder)
36
+
37
+
38
+ def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
39
+ while True:
40
+ seed = (a * seed + c) % m
41
+ yield seed % 255
42
+
43
+
44
+ def xor_block(block):
45
+ g = lcg()
46
+ randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
47
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
48
+
49
+
50
+ def style_block(block, sequence):
51
+ im = Image.new('RGB', (block.shape[1], block.shape[0]))
52
+ draw = ImageDraw.Draw(im)
53
+ i = 0
54
+ for x in range(-6, im.size[0], 8):
55
+ for yi, y in enumerate(range(-6, im.size[1], 8)):
56
+ offset = 0
57
+ if yi % 2 == 0:
58
+ offset = 4
59
+ shade = sequence[i % len(sequence)]
60
+ i += 1
61
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
62
+
63
+ fg = np.array(im).astype(np.uint8) & 0xF0
64
+
65
+ return block ^ fg
66
+
67
+
68
+ def insert_image_data_embed(image, data):
69
+ d = 3
70
+ data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
71
+ data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
72
+ data_np_high = data_np_ >> 4
73
+ data_np_low = data_np_ & 0x0F
74
+
75
+ h = image.size[1]
76
+ next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
77
+ next_size = next_size + ((h*d)-(next_size % (h*d)))
78
+
79
+ data_np_low = np.resize(data_np_low, next_size)
80
+ data_np_low = data_np_low.reshape((h, -1, d))
81
+
82
+ data_np_high = np.resize(data_np_high, next_size)
83
+ data_np_high = data_np_high.reshape((h, -1, d))
84
+
85
+ edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
86
+ edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
87
+
88
+ data_np_low = style_block(data_np_low, sequence=edge_style)
89
+ data_np_low = xor_block(data_np_low)
90
+ data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
91
+ data_np_high = xor_block(data_np_high)
92
+
93
+ im_low = Image.fromarray(data_np_low, mode='RGB')
94
+ im_high = Image.fromarray(data_np_high, mode='RGB')
95
+
96
+ background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
97
+ background.paste(im_low, (0, 0))
98
+ background.paste(image, (im_low.size[0]+1, 0))
99
+ background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
100
+
101
+ return background
102
+
103
+
104
+ def crop_black(img, tol=0):
105
+ mask = (img > tol).all(2)
106
+ mask0, mask1 = mask.any(0), mask.any(1)
107
+ col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
108
+ row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
109
+ return img[row_start:row_end, col_start:col_end]
110
+
111
+
112
+ def extract_image_data_embed(image):
113
+ d = 3
114
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
115
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
116
+ if black_cols[0].shape[0] < 2:
117
+ print('No Image data blocks found.')
118
+ return None
119
+
120
+ data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
121
+ data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
122
+
123
+ data_block_lower = xor_block(data_block_lower)
124
+ data_block_upper = xor_block(data_block_upper)
125
+
126
+ data_block = (data_block_upper << 4) | (data_block_lower)
127
+ data_block = data_block.flatten().tobytes()
128
+
129
+ data = zlib.decompress(data_block)
130
+ return json.loads(data, cls=EmbeddingDecoder)
131
+
132
+
133
+ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
134
+ from modules.images import get_font
135
+ if textfont:
136
+ warnings.warn(
137
+ 'passing in a textfont to caption_image_overlay is deprecated and does nothing',
138
+ DeprecationWarning,
139
+ stacklevel=2,
140
+ )
141
+ from math import cos
142
+
143
+ image = srcimage.copy()
144
+ fontsize = 32
145
+ factor = 1.5
146
+ gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
147
+ for y in range(image.size[1]):
148
+ mag = 1-cos(y/image.size[1]*factor)
149
+ mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
150
+ gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
151
+ image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
152
+
153
+ draw = ImageDraw.Draw(image)
154
+
155
+ font = get_font(fontsize)
156
+ padding = 10
157
+
158
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
159
+ fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
160
+ font = get_font(fontsize)
161
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
162
+ draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
163
+
164
+ _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
165
+ fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
166
+ _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
167
+ fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
168
+ _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
169
+ fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
170
+
171
+ font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))
172
+
173
+ draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
174
+ draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
175
+ draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
176
+
177
+ return image
178
+
179
+
180
+ if __name__ == '__main__':
181
+
182
+ testEmbed = Image.open('test_embedding.png')
183
+ data = extract_image_data_embed(testEmbed)
184
+ assert data is not None
185
+
186
+ data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
187
+ assert data is not None
188
+
189
+ image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
190
+ cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
191
+
192
+ test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
193
+
194
+ embedded_image = insert_image_data_embed(cap_image, test_embed)
195
+
196
+ retrived_embed = extract_image_data_embed(embedded_image)
197
+
198
+ assert str(retrived_embed) == str(test_embed)
199
+
200
+ embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
201
+
202
+ assert embedded_image == embedded_image2
203
+
204
+ g = lcg()
205
+ shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
206
+
207
+ reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
208
+ 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
209
+ 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
210
+ 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
211
+ 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
212
+ 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
213
+ 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
214
+ 204, 86, 73, 222, 44, 198, 118, 240, 97]
215
+
216
+ assert shared_random == reference_random
217
+
218
+ hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
219
+
220
+ assert 12731374 == hunna_kay_random_sum
modules/textual_inversion/learn_schedule.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+
3
+
4
+ class LearnScheduleIterator:
5
+ def __init__(self, learn_rate, max_steps, cur_step=0):
6
+ """
7
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
8
+ """
9
+
10
+ pairs = learn_rate.split(',')
11
+ self.rates = []
12
+ self.it = 0
13
+ self.maxit = 0
14
+ try:
15
+ for pair in pairs:
16
+ if not pair.strip():
17
+ continue
18
+ tmp = pair.split(':')
19
+ if len(tmp) == 2:
20
+ step = int(tmp[1])
21
+ if step > cur_step:
22
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
23
+ self.maxit += 1
24
+ if step > max_steps:
25
+ return
26
+ elif step == -1:
27
+ self.rates.append((float(tmp[0]), max_steps))
28
+ self.maxit += 1
29
+ return
30
+ else:
31
+ self.rates.append((float(tmp[0]), max_steps))
32
+ self.maxit += 1
33
+ return
34
+ assert self.rates
35
+ except (ValueError, AssertionError) as e:
36
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
37
+
38
+
39
+ def __iter__(self):
40
+ return self
41
+
42
+ def __next__(self):
43
+ if self.it < self.maxit:
44
+ self.it += 1
45
+ return self.rates[self.it - 1]
46
+ else:
47
+ raise StopIteration
48
+
49
+
50
+ class LearnRateScheduler:
51
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
52
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
53
+ (self.learn_rate, self.end_step) = next(self.schedules)
54
+ self.verbose = verbose
55
+
56
+ if self.verbose:
57
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
58
+
59
+ self.finished = False
60
+
61
+ def step(self, step_number):
62
+ if step_number < self.end_step:
63
+ return False
64
+
65
+ try:
66
+ (self.learn_rate, self.end_step) = next(self.schedules)
67
+ except StopIteration:
68
+ self.finished = True
69
+ return False
70
+ return True
71
+
72
+ def apply(self, optimizer, step_number):
73
+ if not self.step(step_number):
74
+ return
75
+
76
+ if self.verbose:
77
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
78
+
79
+ for pg in optimizer.param_groups:
80
+ pg['lr'] = self.learn_rate
81
+