e2079503703a6d257b7f98add64aa93c94c0f12610c7aca1e434fe98ffb56c3c
- modules/progress.py +129 -0
- modules/prompt_parser.py +439 -0
- modules/realesrgan_model.py +132 -0
- modules/restart.py +23 -0
- modules/safe.py +196 -0
- modules/script_callbacks.py +453 -0
- modules/script_loading.py +31 -0
- modules/scripts.py +680 -0
- modules/scripts_auto_postprocessing.py +42 -0
- modules/scripts_postprocessing.py +152 -0
- modules/sd_disable_initialization.py +93 -0
- modules/sd_hijack.py +346 -0
- modules/sd_hijack_checkpoint.py +46 -0
- modules/sd_hijack_clip.py +349 -0
- modules/sd_hijack_clip_old.py +82 -0
- modules/sd_hijack_inpainting.py +97 -0
- modules/sd_hijack_ip2p.py +10 -0
- modules/sd_hijack_open_clip.py +71 -0
- modules/sd_hijack_optimizations.py +668 -0
- modules/sd_hijack_unet.py +85 -0
- modules/sd_hijack_utils.py +28 -0
- modules/sd_hijack_xlmr.py +32 -0
- modules/sd_models.py +643 -0
- modules/sd_models_config.py +125 -0
- modules/sd_models_xl.py +99 -0
- modules/sd_samplers.py +56 -0
- modules/sd_samplers_common.py +95 -0
- modules/sd_samplers_compvis.py +224 -0
- modules/sd_samplers_kdiffusion.py +476 -0
- modules/sd_unet.py +92 -0
- modules/sd_vae.py +213 -0
- modules/sd_vae_approx.py +86 -0
- modules/sd_vae_taesd.py +88 -0
- modules/shared.py +912 -0
- modules/shared_items.py +69 -0
- modules/styles.py +139 -0
- modules/sub_quadratic_attention.py +215 -0
- modules/sysinfo.py +162 -0
- modules/textual_inversion/__pycache__/autocrop.cpython-310.pyc +0 -0
- modules/textual_inversion/__pycache__/dataset.cpython-310.pyc +0 -0
- modules/textual_inversion/__pycache__/image_embedding.cpython-310.pyc +0 -0
- modules/textual_inversion/__pycache__/learn_schedule.cpython-310.pyc +0 -0
- modules/textual_inversion/__pycache__/logging.cpython-310.pyc +0 -0
- modules/textual_inversion/__pycache__/preprocess.cpython-310.pyc +0 -0
- modules/textual_inversion/__pycache__/textual_inversion.cpython-310.pyc +0 -0
- modules/textual_inversion/__pycache__/ui.cpython-310.pyc +0 -0
- modules/textual_inversion/autocrop.py +340 -0
- modules/textual_inversion/dataset.py +246 -0
- modules/textual_inversion/image_embedding.py +220 -0
- modules/textual_inversion/learn_schedule.py +81 -0
modules/progress.py
ADDED
@@ -0,0 +1,129 @@
import base64
import io
import time

import gradio as gr
from pydantic import BaseModel, Field

from modules.shared import opts

import modules.shared as shared


current_task = None
pending_tasks = {}
finished_tasks = []
recorded_results = []
recorded_results_limit = 2


def start_task(id_task):
    global current_task

    current_task = id_task
    pending_tasks.pop(id_task, None)


def finish_task(id_task):
    global current_task

    if current_task == id_task:
        current_task = None

    finished_tasks.append(id_task)
    if len(finished_tasks) > 16:
        finished_tasks.pop(0)


def record_results(id_task, res):
    recorded_results.append((id_task, res))
    if len(recorded_results) > recorded_results_limit:
        recorded_results.pop(0)


def add_task_to_queue(id_job):
    pending_tasks[id_job] = time.time()


class ProgressRequest(BaseModel):
    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received preview image")


class ProgressResponse(BaseModel):
    active: bool = Field(title="Whether the task is being worked on right now")
    queued: bool = Field(title="Whether the task is in queue")
    completed: bool = Field(title="Whether the task has already finished")
    progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
    eta: float = Field(default=None, title="ETA in secs")
    live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
    id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with the next request to avoid receiving the same image again")
    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")


def setup_progress_api(app):
    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)


def progressapi(req: ProgressRequest):
    active = req.id_task == current_task
    queued = req.id_task in pending_tasks
    completed = req.id_task in finished_tasks

    if not active:
        return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")

    progress = 0

    job_count, job_no = shared.state.job_count, shared.state.job_no
    sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step

    if job_count > 0:
        progress += job_no / job_count
    if sampling_steps > 0 and job_count > 0:
        progress += 1 / job_count * sampling_step / sampling_steps

    progress = min(progress, 1)

    elapsed_since_start = time.time() - shared.state.time_start
    predicted_duration = elapsed_since_start / progress if progress > 0 else None
    eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None

    id_live_preview = req.id_live_preview
    shared.state.set_current_image()
    if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
        image = shared.state.current_image
        if image is not None:
            buffered = io.BytesIO()

            if opts.live_previews_image_format == "png":
                # using optimize for large images takes an enormous amount of time
                if max(*image.size) <= 256:
                    save_kwargs = {"optimize": True}
                else:
                    save_kwargs = {"optimize": False, "compress_level": 1}
            else:
                save_kwargs = {}

            image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
            base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
            live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
            id_live_preview = shared.state.id_live_preview
        else:
            live_preview = None
    else:
        live_preview = None

    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)


def restore_progress(id_task):
    while id_task == current_task or id_task in pending_tasks:
        time.sleep(0.1)

    res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
    if res is not None:
        return res

    return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
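
This module only defines the server side of the progress protocol, so a minimal client-side sketch may help show how the route is meant to be used. The sketch below is not part of the commit; the host/port and the `requests` dependency are assumptions:

```python
# Minimal polling sketch for the /internal/progress route defined above.
# Assumes the webui listens on http://127.0.0.1:7860 and requests is installed.
import time
import requests

def poll_progress(id_task, interval=0.5):
    id_live_preview = -1
    while True:
        resp = requests.post(
            "http://127.0.0.1:7860/internal/progress",
            json={"id_task": id_task, "id_live_preview": id_live_preview},
        ).json()
        if resp.get("completed"):
            return resp
        # echo back the preview id so the server skips resending the same image
        id_live_preview = resp.get("id_live_preview", id_live_preview)
        print(f"progress: {resp.get('progress')}, eta: {resp.get('eta')}")
        time.sleep(interval)
```

Note the `id_live_preview` round trip: the client stores the id from each response and sends it with the next request, which is what lets `progressapi` return `live_preview=None` when nothing new is available.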
modules/prompt_parser.py
ADDED
@@ -0,0 +1,439 @@
from __future__ import annotations

import re
from collections import namedtuple
from typing import List
import lark

# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']

schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")

def get_learned_conditioning_prompt_schedules(prompts, steps):
    """
    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b d'], [10, 'a c d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c")  # not handling this right now
    [[5, 'a c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    >>> g("[a|(b:1.1)]")
    [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
    """

    def collect_steps(steps, tree):
        res = [steps]

        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                tree.children[-1] = float(tree.children[-1])
                if tree.children[-1] < 1:
                    tree.children[-1] *= steps

                tree.children[-1] = min(steps, int(tree.children[-1]))
                res.append(tree.children[-1])

            def alternate(self, tree):
                res.extend(range(1, steps+1))

        CollectSteps().visit(tree)
        return sorted(set(res))

    def at_step(step, tree):
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                before, after, _, when = args
                yield before or () if step <= when else after

            def alternate(self, args):
                yield next(args[(step - 1) % len(args)])

            def start(self, args):
                def flatten(x):
                    if type(x) == str:
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)

                return ''.join(flatten(args))

            def plain(self, args):
                yield args[0].value

            def __default__(self, data, children, meta):
                for child in children:
                    yield child

        return AtStep().transform(tree)

    def get_schedule(prompt):
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError:
            if 0:
                import traceback
                traceback.print_exc()
            return [[steps, prompt]]
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
    return [promptdict[prompt] for prompt in prompts]


ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])


class SdConditioning(list):
    """
    A list with prompts for stable diffusion's conditioner model.
    Can also specify width and height of created image - SDXL needs it.
    """
    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
        super().__init__()
        self.extend(prompts)

        if copy_from is None:
            copy_from = prompts

        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
        self.width = width or getattr(copy_from, 'width', None)
        self.height = height or getattr(copy_from, 'height', None)


def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
    and the sampling step at which this condition is to be replaced by the next one.

    Input:
    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)

    Output:
    [
        [
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
        ],
        [
            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
        ]
    ]
    """
    res = []

    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
    cache = {}

    for prompt, prompt_schedule in zip(prompts, prompt_schedules):

        cached = cache.get(prompt, None)
        if cached is not None:
            res.append(cached)
            continue

        texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
        conds = model.get_learned_conditioning(texts)

        cond_schedule = []
        for i, (end_at_step, _) in enumerate(prompt_schedule):
            if isinstance(conds, dict):
                cond = {k: v[i] for k, v in conds.items()}
            else:
                cond = conds[i]

            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))

        cache[prompt] = cond_schedule
        res.append(cond_schedule)

    return res


re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")


def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
    res_indexes = []

    prompt_indexes = {}
    prompt_flat_list = SdConditioning(prompts)
    prompt_flat_list.clear()

    for prompt in prompts:
        subprompts = re_AND.split(prompt)

        indexes = []
        for subprompt in subprompts:
            match = re_weight.search(subprompt)

            text, weight = match.groups() if match is not None else (subprompt, 1.0)

            weight = float(weight) if weight is not None else 1.0

            index = prompt_indexes.get(text, None)
            if index is None:
                index = len(prompt_flat_list)
                prompt_flat_list.append(text)
                prompt_indexes[text] = index

            indexes.append((index, weight))

        res_indexes.append(indexes)

    return res_indexes, prompt_flat_list, prompt_indexes


class ComposableScheduledPromptConditioning:
    def __init__(self, schedules, weight=1.0):
        self.schedules: List[ScheduledPromptConditioning] = schedules
        self.weight: float = weight


class MulticondLearnedConditioning:
    def __init__(self, shape, batch):
        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch


def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
    """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
    For each prompt, the list is obtained by splitting the prompt using the AND separator.

    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
    """

    res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)

    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)

    res = []
    for indexes in res_indexes:
        res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])

    return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)


class DictWithShape(dict):
    def __init__(self, x, shape):
        super().__init__()
        self.update(x)

    @property
    def shape(self):
        return self["crossattn"].shape


def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
    param = c[0][0].cond
    is_dict = isinstance(param, dict)

    if is_dict:
        dict_cond = param
        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
    else:
        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)

    for i, cond_schedule in enumerate(c):
        target_index = 0
        for current, entry in enumerate(cond_schedule):
            if current_step <= entry.end_at_step:
                target_index = current
                break

        if is_dict:
            for k, param in cond_schedule[target_index].cond.items():
                res[k][i] = param
        else:
            res[i] = cond_schedule[target_index].cond

    return res


def stack_conds(tensors):
    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
    # and won't be able to torch.stack them. So this fixes that.
    token_count = max([x.shape[0] for x in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] != token_count:
            last_vector = tensors[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])

    return torch.stack(tensors)


def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
    param = c.batch[0][0].schedules[0].cond

    tensors = []
    conds_list = []

    for composable_prompts in c.batch:
        conds_for_batch = []

        for composable_prompt in composable_prompts:
            target_index = 0
            for current, entry in enumerate(composable_prompt.schedules):
                if current_step <= entry.end_at_step:
                    target_index = current
                    break

            conds_for_batch.append((len(tensors), composable_prompt.weight))
            tensors.append(composable_prompt.schedules[target_index].cond)

        conds_list.append(conds_for_batch)

    if isinstance(tensors[0], dict):
        keys = list(tensors[0].keys())
        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
    else:
        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)

    return conds_list, stacked


re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)

def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])
        elif text == '(':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re_break, text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res

if __name__ == "__main__":
    import doctest
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
else:
    import torch  # doctest faster
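
The two parsing entry points above are pure functions, so they can be exercised directly. An illustrative sketch, not part of the commit; it assumes the repository root is on sys.path and that lark and torch are installed (importing the module also imports torch):

```python
# Quick check of prompt editing and attention-weight parsing.
from modules import prompt_parser

# prompt editing: [mountain:lake:0.25] switches at 25% of the 10 steps
print(prompt_parser.get_learned_conditioning_prompt_schedules(
    ["a [mountain:lake:0.25] at dawn"], 10))
# -> [[[2, 'a mountain at dawn'], [10, 'a lake at dawn']]]

# attention weighting: (word:w) multiplies weight by w, [word] by 1/1.1
print(prompt_parser.parse_prompt_attention("a (red:1.3) ball [indoors]"))
# -> [['a ', 1.0], ['red', 1.3], [' ball ', 1.0], ['indoors', 0.9090909090909091]]
```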
modules/realesrgan_model.py
ADDED
@@ -0,0 +1,132 @@
import os

import numpy as np
from PIL import Image
from realesrgan import RealESRGANer

from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader, errors


class UpscalerRealESRGAN(Upscaler):
    def __init__(self, path):
        self.name = "RealESRGAN"
        self.user_path = path
        super().__init__()
        try:
            from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
            from realesrgan import RealESRGANer  # noqa: F401
            from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401
            self.enable = True
            self.scalers = []
            scalers = self.load_models(path)

            local_model_paths = self.find_models(ext_filter=[".pth"])
            for scaler in scalers:
                if scaler.local_data_path.startswith("http"):
                    filename = modelloader.friendly_name(scaler.local_data_path)
                    local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
                    if local_model_candidates:
                        scaler.local_data_path = local_model_candidates[0]

                if scaler.name in opts.realesrgan_enabled_models:
                    self.scalers.append(scaler)

        except Exception:
            errors.report("Error importing Real-ESRGAN", exc_info=True)
            self.enable = False
            self.scalers = []

    def do_upscale(self, img, path):
        if not self.enable:
            return img

        try:
            info = self.load_model(path)
        except Exception:
            errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
            return img

        upsampler = RealESRGANer(
            scale=info.scale,
            model_path=info.local_data_path,
            model=info.model(),
            half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
            tile=opts.ESRGAN_tile,
            tile_pad=opts.ESRGAN_tile_overlap,
        )

        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]

        image = Image.fromarray(upsampled)
        return image

    def load_model(self, path):
        for scaler in self.scalers:
            if scaler.data_path == path:
                if scaler.local_data_path.startswith("http"):
                    scaler.local_data_path = modelloader.load_file_from_url(
                        scaler.data_path,
                        model_dir=self.model_download_path,
                    )
                if not os.path.exists(scaler.local_data_path):
                    raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
                return scaler
        raise ValueError(f"Unable to find model info: {path}")

    def load_models(self, _):
        return get_realesrgan_models(self)


def get_realesrgan_models(scaler):
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        models = [
            UpscalerData(
                name="R-ESRGAN General 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN General WDN 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN AnimeVideo",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN 4x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 4x+ Anime6B",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 2x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                scale=2,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            ),
        ]
        return models
    except Exception:
        errors.report("Error making Real-ESRGAN models list", exc_info=True)
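
A hypothetical sketch of driving the upscaler outside the UI; it assumes the webui environment (shared options, model download path, Real-ESRGAN dependencies) has already been initialized, which normally happens during webui startup, and it is not part of the commit:

```python
# Upscale a single image with one of the bundled Real-ESRGAN model entries.
from PIL import Image

from modules.realesrgan_model import UpscalerRealESRGAN

upscaler = UpscalerRealESRGAN(path=None)  # None -> default model location
scaler = next(s for s in upscaler.scalers if s.name == "R-ESRGAN 4x+")

img = Image.open("input.png").convert("RGB")
result = upscaler.do_upscale(img, scaler.data_path)  # downloads weights on first use
result.save("output_4x.png")
```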
modules/restart.py
ADDED
@@ -0,0 +1,23 @@
import os
from pathlib import Path

from modules.paths_internal import script_path


def is_restartable() -> bool:
    """
    Return True if the webui is restartable (i.e. there is something watching to restart it with)
    """
    return bool(os.environ.get('SD_WEBUI_RESTART'))


def restart_program() -> None:
    """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""

    (Path(script_path) / "tmp" / "restart").touch()

    stop_program()


def stop_program() -> None:
    os._exit(0)
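
restart_program() only works if something outside the process relaunches it. A rough Python sketch of the watcher loop that webui.bat/webui.sh implement in shell; the env var value and the relative "tmp/restart" path here are illustrative assumptions, not part of the commit:

```python
# Outer watcher: relaunch the webui whenever it exits after touching tmp/restart.
import os
import subprocess

os.environ["SD_WEBUI_RESTART"] = "1"  # makes is_restartable() return True
while True:
    subprocess.run(["python", "launch.py"])  # blocks until os._exit(0)
    if not os.path.exists("tmp/restart"):
        break                                # normal shutdown: stop relaunching
    os.remove("tmp/restart")                 # restart requested: loop and relaunch
```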
modules/safe.py
ADDED
@@ -0,0 +1,196 @@
# this code is adapted from the script contributed by anon from /h/

import pickle
import collections

import torch
import numpy
import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors

TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage

def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'

        try:
            return TypedStorage(_internal=True)
        except TypeError:
            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")

def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            "-----> !!!! The file is most likely corrupted !!!! <-----\n"
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
            exc_info=True,
        )
        return None
    except Exception:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            f"The file may be malicious, so the program is not going to read it.\n"
            f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
            exc_info=True,
        )
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
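
A small demonstration of the whitelist in action: a pickle that references any global outside the allowed set is rejected before it can run code. This is a sketch, not part of the commit, and it assumes the module's own imports (torch, modules.errors) resolve:

```python
# Feed the restricted unpickler a pickle that references builtins.eval.
import io
import pickle

from modules import safe

payload = pickle.dumps(eval)  # pickles builtins.eval by reference
try:
    safe.RestrictedUnpickler(io.BytesIO(payload)).load()
except Exception as e:
    print(e)  # -> global 'builtins/eval' is forbidden
```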
modules/script_callbacks.py
ADDED
@@ -0,0 +1,453 @@
import inspect
import os
from collections import namedtuple
from typing import Optional, Dict, Any

from fastapi import FastAPI
from gradio import Blocks

from modules import errors, timer


def report_exception(c, job):
    errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)


class ImageSaveParams:
    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        """the PIL image itself"""

        self.p = p
        """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""

        self.filename = filename
        """name of file that the image would be saved to"""

        self.pnginfo = pnginfo
        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""


class CFGDenoiserParams:
    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.image_cond = image_cond
        """Conditioning image"""

        self.sigma = sigma
        """Current sigma noise step value"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""

        self.text_cond = text_cond
        """Encoder hidden states of text conditioning from prompt"""

        self.text_uncond = text_uncond
        """Encoder hidden states of text conditioning from negative prompt"""


class CFGDenoisedParams:
    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""

        self.inner_model = inner_model
        """Inner model reference used for denoising"""


class AfterCFGCallbackParams:
    def __init__(self, x, sampling_step, total_sampling_steps):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""


class UiTrainTabParams:
    def __init__(self, txt2img_preview_params):
        self.txt2img_preview_params = txt2img_preview_params


class ImageGridLoopParams:
    def __init__(self, imgs, cols, rows):
        self.imgs = imgs
        self.cols = cols
        self.rows = rows


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
    callbacks_app_started=[],
    callbacks_model_loaded=[],
    callbacks_ui_tabs=[],
    callbacks_ui_train_tabs=[],
    callbacks_ui_settings=[],
    callbacks_before_image_saved=[],
    callbacks_image_saved=[],
    callbacks_cfg_denoiser=[],
    callbacks_cfg_denoised=[],
    callbacks_cfg_after_cfg=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
    callbacks_image_grid=[],
    callbacks_infotext_pasted=[],
    callbacks_script_unloaded=[],
    callbacks_before_ui=[],
    callbacks_on_reload=[],
    callbacks_list_optimizers=[],
    callbacks_list_unets=[],
)


def clear_callbacks():
    for callback_list in callback_map.values():
        callback_list.clear()


def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    for c in callback_map['callbacks_app_started']:
        try:
            c.callback(demo, app)
            timer.startup_timer.record(os.path.basename(c.script))
        except Exception:
            report_exception(c, 'app_started_callback')


def app_reload_callback():
    for c in callback_map['callbacks_on_reload']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'callbacks_on_reload')


def model_loaded_callback(sd_model):
    for c in callback_map['callbacks_model_loaded']:
        try:
            c.callback(sd_model)
        except Exception:
            report_exception(c, 'model_loaded_callback')


def ui_tabs_callback():
    res = []

    for c in callback_map['callbacks_ui_tabs']:
        try:
            res += c.callback() or []
        except Exception:
            report_exception(c, 'ui_tabs_callback')

    return res


def ui_train_tabs_callback(params: UiTrainTabParams):
    for c in callback_map['callbacks_ui_train_tabs']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'callbacks_ui_train_tabs')


def ui_settings_callback():
    for c in callback_map['callbacks_ui_settings']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'ui_settings_callback')


def before_image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_before_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')


def image_saved_callback(params: ImageSaveParams):
    for c in callback_map['callbacks_image_saved']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_saved_callback')


def cfg_denoiser_callback(params: CFGDenoiserParams):
    for c in callback_map['callbacks_cfg_denoiser']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoiser_callback')


def cfg_denoised_callback(params: CFGDenoisedParams):
    for c in callback_map['callbacks_cfg_denoised']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_denoised_callback')


def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
    for c in callback_map['callbacks_cfg_after_cfg']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'cfg_after_cfg_callback')


def before_component_callback(component, **kwargs):
    for c in callback_map['callbacks_before_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'before_component_callback')


def after_component_callback(component, **kwargs):
    for c in callback_map['callbacks_after_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'after_component_callback')


def image_grid_callback(params: ImageGridLoopParams):
    for c in callback_map['callbacks_image_grid']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_grid')


def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
    for c in callback_map['callbacks_infotext_pasted']:
        try:
            c.callback(infotext, params)
        except Exception:
            report_exception(c, 'infotext_pasted')


def script_unloaded_callback():
    for c in reversed(callback_map['callbacks_script_unloaded']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'script_unloaded')


def before_ui_callback():
    for c in reversed(callback_map['callbacks_before_ui']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'before_ui')


def list_optimizers_callback():
    res = []

    for c in callback_map['callbacks_list_optimizers']:
        try:
            c.callback(res)
        except Exception:
            report_exception(c, 'list_optimizers')

    return res


def list_unets_callback():
    res = []

    for c in callback_map['callbacks_list_unets']:
        try:
            c.callback(res)
        except Exception:
            report_exception(c, 'list_unets')

    return res


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if stack else 'unknown file'

    callbacks.append(ScriptCallback(filename, fun))


def remove_current_script_callbacks():
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if stack else 'unknown file'
    if filename == 'unknown file':
        return
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
            callback_list.remove(callback_to_remove)


def remove_callbacks_for_function(callback_func):
    for callback_list in callback_map.values():
        for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
            callback_list.remove(callback_to_remove)


def on_app_started(callback):
    """register a function to be called when the webui starts; the gradio `Blocks` component and
    fastapi `FastAPI` object are passed as the arguments"""
    add_callback(callback_map['callbacks_app_started'], callback)


def on_before_reload(callback):
    """register a function to be called just before the server reloads."""
    add_callback(callback_map['callbacks_on_reload'], callback)


def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument; this function is also called when the script is reloaded. """
    add_callback(callback_map['callbacks_model_loaded'], callback)


def on_ui_tabs(callback):
    """register a function to be called when the UI is creating new tabs.
    The function must either return None, which means no new tabs to be added, or a list, where
    each element is a tuple:
        (gradio_component, title, elem_id)

    gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
    title is tab text displayed to user in the UI
    elem_id is HTML id for the tab
    """
    add_callback(callback_map['callbacks_ui_tabs'], callback)


def on_ui_train_tabs(callback):
    """register a function to be called when the UI is creating new tabs for the train tab.
    Create your new tabs with gr.Tab.
    """
    add_callback(callback_map['callbacks_ui_train_tabs'], callback)


def on_ui_settings(callback):
    """register a function to be called before UI settings are populated; add your settings
    by using shared.opts.add_option(shared.OptionInfo(...)) """
    add_callback(callback_map['callbacks_ui_settings'], callback)


def on_before_image_saved(callback):
    """register a function to be called before an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
    """
    add_callback(callback_map['callbacks_before_image_saved'], callback)


def on_image_saved(callback):
    """register a function to be called after an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callback_map['callbacks_image_saved'], callback)


def on_cfg_denoiser(callback):
    """register a function to be called in the kdiffusion cfg_denoiser method after building the inner model inputs.
    The callback is called with one argument:
        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoiser'], callback)


def on_cfg_denoised(callback):
    """register a function to be called in the kdiffusion cfg_denoiser method after building the inner model inputs.
    The callback is called with one argument:
        - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoised'], callback)


def on_cfg_after_cfg(callback):
    """register a function to be called in the kdiffusion cfg_denoiser method after cfg calculations are completed.
    The callback is called with one argument:
        - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
    """
    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)


def on_before_component(callback):
    """register a function to be called before a component is created.
    The callback is called with arguments:
        - component - gradio component that is about to be created.
        - **kwargs - args to gradio.components.IOComponent.__init__ function

    Use elem_id/label fields of kwargs to figure out which component it is.
    This can be useful to inject your own components somewhere in the middle of vanilla UI.
    """
    add_callback(callback_map['callbacks_before_component'], callback)


def on_after_component(callback):
    """register a function to be called after a component is created. See on_before_component for more."""
    add_callback(callback_map['callbacks_after_component'], callback)


def on_image_grid(callback):
    """register a function to be called before making an image grid.
    The callback is called with one argument:
        - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
    """
    add_callback(callback_map['callbacks_image_grid'], callback)


def on_infotext_pasted(callback):
    """register a function to be called before applying an infotext.
    The callback is called with two arguments:
        - infotext: str - raw infotext.
        - result: Dict[str, any] - parsed infotext parameters.
    """
    add_callback(callback_map['callbacks_infotext_pasted'], callback)


def on_script_unloaded(callback):
    """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
    the script did should be reverted here"""

    add_callback(callback_map['callbacks_script_unloaded'], callback)
|
431 |
+
|
432 |
+
add_callback(callback_map['callbacks_script_unloaded'], callback)
|
433 |
+
|
434 |
+
|
435 |
+
def on_before_ui(callback):
|
436 |
+
"""register a function to be called before the UI is created."""
|
437 |
+
|
438 |
+
add_callback(callback_map['callbacks_before_ui'], callback)
|
439 |
+
|
440 |
+
|
441 |
+
def on_list_optimizers(callback):
|
442 |
+
"""register a function to be called when UI is making a list of cross attention optimization options.
|
443 |
+
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
|
444 |
+
to it."""
|
445 |
+
|
446 |
+
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
447 |
+
|
448 |
+
|
449 |
+
def on_list_unets(callback):
|
450 |
+
"""register a function to be called when UI is making a list of alternative options for unet.
|
451 |
+
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
452 |
+
|
453 |
+
add_callback(callback_map['callbacks_list_unets'], callback)
|
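For orientation, a minimal sketch of how an extension script would typically consume the registration API above. The option key "example_setting" and the callback bodies are hypothetical, not part of this commit:

# Hypothetical extension script using the script_callbacks registration API.
from modules import script_callbacks, shared


def app_started(blocks, app):
    # receives the gradio Blocks and the FastAPI app, per on_app_started's docstring
    print("webui is up")


def add_settings():
    # on_ui_settings callbacks add options via shared.opts.add_option(...)
    shared.opts.add_option(
        "example_setting",
        shared.OptionInfo(True, "Example setting", section=("example", "Example")),
    )


script_callbacks.on_app_started(app_started)
script_callbacks.on_ui_settings(add_settings)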
modules/script_loading.py
ADDED
@@ -0,0 +1,31 @@
import os
import importlib.util

from modules import errors


def load_module(path):
    module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    return module


def preload_extensions(extensions_dir, parser, extension_list=None):
    if not os.path.isdir(extensions_dir):
        return

    extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
    for dirname in sorted(extensions):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue

        try:
            module = load_module(preload_script)
            if hasattr(module, 'preload'):
                module.preload(parser)

        except Exception:
            errors.report(f"Error running preload() for {preload_script}", exc_info=True)
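A sketch of the preload.py file that preload_extensions() looks for in each extension directory: it only needs to define preload(parser), which receives the argparse parser before command-line arguments are parsed. The flag name below is illustrative:

# Hypothetical extensions/my-extension/preload.py picked up by preload_extensions().
def preload(parser):
    parser.add_argument(
        "--my-extension-models-dir",
        type=str,
        default=None,
        help="directory with models for my-extension",
    )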
modules/scripts.py
ADDED
@@ -0,0 +1,680 @@
import os
import re
import sys
import inspect
from collections import namedtuple

import gradio as gr

from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer

AlwaysVisible = object()


class PostprocessImageArgs:
    def __init__(self, image):
        self.image = image


class PostprocessBatchListArgs:
    def __init__(self, images):
        self.images = images


class Script:
    name = None
    """script's internal name derived from title"""

    section = None
    """name of UI section that the script's controls will be placed into"""

    filename = None
    args_from = None
    args_to = None
    alwayson = False

    is_txt2img = False
    is_img2img = False

    group = None
    """A gr.Group component that has all script's UI inside it"""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """

    paste_field_names = None
    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
    various "Send to <X>" buttons when clicked
    """

    api_info = None
    """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

        pass

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - script.AlwaysVisible if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        pass

    def before_process(self, p, *args):
        """
        This function is called very early before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def before_process_batch(self, p, *args, **kwargs):
        """
        Called before extra networks are parsed from the prompt, so you can add
        new extra network keywords to the prompt with this callback.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def after_extra_networks_activate(self, p, *args, **kwargs):
        """
        Called after extra networks activation, before conds calculation.
        Allows modification of the network after extra networks activation has been applied.
        Won't be called if p.disable_extra_networks is set.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
          - extra_network_data - list of ExtraNetworkParams for current stage
        """
        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
          - images - torch tensor with all generated images, with values ranging from 0 to 1;
        """

        pass

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as above.
        """

        pass

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""

        need_tabname = self.show(True) == self.show(False)
        tabkind = 'img2img' if self.is_img2img else 'txt2txt'
        tabname = f"{tabkind}_" if need_tabname else ""
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

    def before_hr(self, p, *args):
        """
        This function is called before hires fix start.
        """
        pass
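For orientation, a minimal sketch of a selectable script built on the Script API above; the class name, the label, and the suffix logic are illustrative, not part of this commit:

# Hypothetical scripts/example_script.py implementing the Script interface.
import gradio as gr

from modules import scripts
from modules.processing import process_images


class ExampleScript(scripts.Script):
    def title(self):
        return "Example script"

    def ui(self, is_img2img):
        # components returned here arrive as positional args in run()
        suffix = gr.Textbox(label="Prompt suffix", elem_id=self.elem_id("suffix"))
        return [suffix]

    def run(self, p, suffix):
        # mutate the processing object, then do the standard processing
        if suffix:
            p.prompt = f"{p.prompt}, {suffix}"
        return process_images(p)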
current_basedir = paths.script_path


def basedir():
    """returns the base directory for the current script. For scripts in the main scripts directory,
    this is the main directory (where webui.py resides), and for scripts in extensions directory
    (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
    """
    return current_basedir


ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])

scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])


def list_scripts(scriptdirname, extension):
    scripts_list = []

    basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(basedir):
        for filename in sorted(os.listdir(basedir)):
            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))

    for ext in extensions.active():
        scripts_list += ext.list_files(scriptdirname, extension)

    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

    return scripts_list


def list_files_with_name(filename):
    res = []

    dirs = [paths.script_path] + [ext.path for ext in extensions.active()]

    for dirpath in dirs:
        if not os.path.isdir(dirpath):
            continue

        path = os.path.join(dirpath, filename)
        if os.path.isfile(path):
            res.append(path)

    return res


def load_scripts():
    global current_basedir
    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")

    syspath = sys.path

    def register_scripts_from_module(module):
        for script_class in module.__dict__.values():
            if not inspect.isclass(script_class):
                continue

            if issubclass(script_class, Script):
                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    def orderby(basedir):
        # 1st webui, 2nd extensions-builtin, 3rd extensions
        priority = {os.path.join(paths.script_path, "extensions-builtin"): 1, paths.script_path: 0}
        for key in priority:
            if basedir.startswith(key):
                return priority[key]
        return 9999

    for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
        try:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module)

        except Exception:
            errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path
            timer.startup_timer.record(scriptfile.filename)

    global scripts_txt2img, scripts_img2img, scripts_postproc

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    try:
        return func(*args, **kwargs)
    except Exception:
        errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)

    return default


class ScriptRunner:
    def __init__(self):
        self.scripts = []
        self.selectable_scripts = []
        self.alwayson_scripts = []
        self.titles = []
        self.infotext_fields = []
        self.paste_field_names = []
        self.inputs = [None]

    def initialize_scripts(self, is_img2img):
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_data in auto_processing_scripts + scripts_data:
            script = script_data.script_class()
            script.filename = script_data.path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def create_script_ui(self, script):
        import modules.api.models as api_models

        script.args_from = len(self.inputs)
        script.args_to = len(self.inputs)

        controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

        if controls is None:
            return

        script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
        api_args = []

        for control in controls:
            control.custom_script_source = os.path.basename(script.filename)

            arg_info = api_models.ScriptArg(label=control.label or "")

            for field in ("value", "minimum", "maximum", "step", "choices"):
                v = getattr(control, field, None)
                if v is not None:
                    setattr(arg_info, field, v)

            api_args.append(arg_info)

        script.api_info = api_models.ScriptInfo(
            name=script.name,
            is_img2img=script.is_img2img,
            is_alwayson=script.alwayson,
            args=api_args,
        )

        if script.infotext_fields is not None:
            self.infotext_fields += script.infotext_fields

        if script.paste_field_names is not None:
            self.paste_field_names += script.paste_field_names

        self.inputs += controls
        script.args_to = len(self.inputs)

    def setup_ui_for_section(self, section, scriptlist=None):
        if scriptlist is None:
            scriptlist = self.alwayson_scripts

        for script in scriptlist:
            if script.alwayson and script.section != section:
                continue

            with gr.Group(visible=script.alwayson) as group:
                self.create_script_ui(script)

            script.group = group

    def prepare_ui(self):
        self.inputs = [None]

    def setup_ui(self):
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        self.setup_ui_for_section(None)

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        self.inputs[0] = dropdown

        self.setup_ui_for_section(None, self.selectable_scripts)

        def select_script(script_index):
            selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            if title == 'None':
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            title = params.get('Script', None)
            if title:
                title_index = self.titles.index(title)
                visibility = title_index == self.script_load_ctr
                self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
                return gr.update(visible=visibility)
            else:
                return gr.update(visible=False)

        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        return self.inputs

    def run(self, p, *args):
        script_index = args[0]

        if script_index == 0:
            return None

        script = self.selectable_scripts[script_index-1]

        if script is None:
            return None

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def before_process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process(p, *script_args)
            except Exception:
                errors.report(f"Error running before_process: {script.filename}", exc_info=True)

    def process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                errors.report(f"Error running process: {script.filename}", exc_info=True)

    def before_process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process_batch(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)

    def after_extra_networks_activate(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.after_extra_networks_activate(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)

    def process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_batch(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running process_batch: {script.filename}", exc_info=True)

    def postprocess(self, p, processed):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                errors.report(f"Error running postprocess: {script.filename}", exc_info=True)

    def postprocess_batch(self, p, images, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch(p, *script_args, images=images, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_image(p, pp, *script_args)
            except Exception:
                errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

    def before_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running before_component: {script.filename}", exc_info=True)

    def after_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running after_component: {script.filename}", exc_info=True)

    def reload_sources(self, cache):
        for si, script in list(enumerate(self.scripts)):
            args_from = script.args_from
            args_to = script.args_to
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            for script_class in module.__dict__.values():
                if type(script_class) == type and issubclass(script_class, Script):
                    self.scripts[si] = script_class()
                    self.scripts[si].filename = filename
                    self.scripts[si].args_from = args_from
                    self.scripts[si].args_to = args_to

    def before_hr(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_hr(p, *script_args)
            except Exception:
                errors.report(f"Error running before_hr: {script.filename}", exc_info=True)


scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None


def reload_script_body_only():
    cache = {}
    scripts_txt2img.reload_sources(cache)
    scripts_img2img.reload_sources(cache)


reload_scripts = load_scripts  # compatibility alias


def add_classes_to_gradio_component(comp):
    """
    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
    """

    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]

    if getattr(comp, 'multiselect', False):
        comp.elem_classes.append('multiselect')


def IOComponent_init(self, *args, **kwargs):
    if scripts_current is not None:
        scripts_current.before_component(self, **kwargs)

    script_callbacks.before_component_callback(self, **kwargs)

    res = original_IOComponent_init(self, *args, **kwargs)

    add_classes_to_gradio_component(self)

    script_callbacks.after_component_callback(self, **kwargs)

    if scripts_current is not None:
        scripts_current.after_component(self, **kwargs)

    return res


original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init


def BlockContext_init(self, *args, **kwargs):
    res = original_BlockContext_init(self, *args, **kwargs)

    add_classes_to_gradio_component(self)

    return res


original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.blocks.BlockContext.__init__ = BlockContext_init
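Complementing the selectable-script sketch above, a sketch of an always-on script: show() returns AlwaysVisible, so ScriptRunner calls process() for every generation instead of run(). Names and the checkbox are illustrative:

# Hypothetical always-on script using the AlwaysVisible mechanism.
import gradio as gr

from modules import scripts


class ExampleAlwaysOn(scripts.Script):
    def title(self):
        return "Example always-on"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        enabled = gr.Checkbox(label="Enable example tweak", value=False)
        return [enabled]

    def process(self, p, enabled):
        if enabled:
            # record the tweak in infotext via extra_generation_params
            p.extra_generation_params["Example tweak"] = True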
modules/scripts_auto_postprocessing.py
ADDED
@@ -0,0 +1,42 @@
from modules import scripts, scripts_postprocessing, shared


class ScriptPostprocessingForMainUI(scripts.Script):
    def __init__(self, script_postproc):
        self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
        self.postprocessing_controls = None

    def title(self):
        return self.script.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        self.postprocessing_controls = self.script.ui()
        return self.postprocessing_controls.values()

    def postprocess_image(self, p, script_pp, *args):
        args_dict = dict(zip(self.postprocessing_controls, args))

        pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
        pp.info = {}
        self.script.process(pp, **args_dict)
        p.extra_generation_params.update(pp.info)
        script_pp.image = pp.image


def create_auto_preprocessing_script_data():
    from modules import scripts

    res = []

    for name in shared.opts.postprocessing_enable_in_main_ui:
        script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
        if script is None:
            continue

        constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
        res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))

    return res
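A note on the `constructor = lambda s=script: ...` line above: the default argument pins each lambda to its own `script`; without it, Python's late binding would make every constructor see the loop's final value. A standalone illustration of the difference:

# late binding: every closure sees the final loop value
fns = [lambda: i for i in range(3)]
assert [f() for f in fns] == [2, 2, 2]

# default-argument capture pins the current value, as done above
fns = [lambda i=i: i for i in range(3)]
assert [f() for f in fns] == [0, 1, 2]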
modules/scripts_postprocessing.py
ADDED
@@ -0,0 +1,152 @@
import os
import gradio as gr

from modules import errors, shared


class PostprocessedImage:
    def __init__(self, image):
        self.image = image
        self.info = {}


class ScriptPostprocessing:
    filename = None
    controls = None
    args_from = None
    args_to = None

    order = 1000
    """scripts will be ordered by this value in postprocessing UI"""

    name = None
    """the name of the script, shown as its title in the postprocessing UI"""

    group = None
    """A gr.Group component that has all script's UI inside it"""

    def ui(self):
        """
        This function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be a dictionary that maps parameter names to components used in processing.
        Values of those components will be passed to process() function.
        """

        pass

    def process(self, pp: PostprocessedImage, **args):
        """
        This function is called to postprocess the image.
        args contains a dictionary with all values returned by components from ui()
        """

        pass

    def image_changed(self):
        pass


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    try:
        res = func(*args, **kwargs)
        return res
    except Exception as e:
        errors.display(e, f"calling {filename}/{funcname}")

    return default


class ScriptPostprocessingRunner:
    def __init__(self):
        self.scripts = None
        self.ui_created = False

    def initialize_scripts(self, scripts_data):
        self.scripts = []

        for script_data in scripts_data:
            script: ScriptPostprocessing = script_data.script_class()
            script.filename = script_data.path

            if script.name == "Simple Upscale":
                continue

            self.scripts.append(script)

    def create_script_ui(self, script, inputs):
        script.args_from = len(inputs)
        script.args_to = len(inputs)

        script.controls = wrap_call(script.ui, script.filename, "ui")

        for control in script.controls.values():
            control.custom_script_source = os.path.basename(script.filename)

        inputs += list(script.controls.values())
        script.args_to = len(inputs)

    def scripts_in_preferred_order(self):
        if self.scripts is None:
            import modules.scripts
            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

        scripts_order = shared.opts.postprocessing_operation_order

        def script_score(name):
            for i, possible_match in enumerate(scripts_order):
                if possible_match == name:
                    return i

            return len(self.scripts)

        script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}

        return sorted(self.scripts, key=lambda x: script_scores[x.name])

    def setup_ui(self):
        inputs = []

        for script in self.scripts_in_preferred_order():
            with gr.Row() as group:
                self.create_script_ui(script, inputs)

            script.group = group

        self.ui_created = True
        return inputs

    def run(self, pp: PostprocessedImage, args):
        for script in self.scripts_in_preferred_order():
            shared.state.job = script.name

            script_args = args[script.args_from:script.args_to]

            process_args = {}
            for (name, _component), value in zip(script.controls.items(), script_args):
                process_args[name] = value

            script.process(pp, **process_args)

    def create_args_for_run(self, scripts_args):
        if not self.ui_created:
            with gr.Blocks(analytics_enabled=False):
                self.setup_ui()

        scripts = self.scripts_in_preferred_order()
        args = [None] * max([x.args_to for x in scripts])

        for script in scripts:
            script_args_dict = scripts_args.get(script.name, None)
            if script_args_dict is not None:

                for i, name in enumerate(script.controls):
                    args[script.args_from + i] = script_args_dict.get(name, None)

        return args

    def image_changed(self):
        for script in self.scripts_in_preferred_order():
            script.image_changed()
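A minimal sketch of a ScriptPostprocessing implementation; the name and the scale knob are illustrative. ui() returns a dict, and process() receives that dict's keys as keyword arguments, matching the runner's zip() routing above:

# Hypothetical postprocessing script built on the ScriptPostprocessing API.
import gradio as gr

from modules import scripts_postprocessing


class ExamplePostprocessing(scripts_postprocessing.ScriptPostprocessing):
    name = "Example resize"
    order = 2000

    def ui(self):
        scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.5, value=1.0, label="Example scale")
        return {"scale": scale}

    def process(self, pp, scale=1.0):
        if scale != 1.0:
            # pp.image is a PIL image; record what we did in pp.info for infotext
            w, h = pp.image.size
            pp.image = pp.image.resize((int(w * scale), int(h * scale)))
            pp.info["Example scale"] = scale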
modules/sd_disable_initialization.py
ADDED
@@ -0,0 +1,93 @@
import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub


class DisableInitialization:
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from working
    - changing CLIP and OpenCLIP to not download model weights
    - changing CLIP to not make requests to check if there is a new version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __init__(self, disable_clip=True):
        self.replaced = []
        self.disable_clip = disable_clip

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
                return None

            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)

        self.replaced.clear()
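The replace()/__exit__ bookkeeping above is a general monkeypatch-and-restore pattern; a stripped-down standalone illustration (the Patched class and the math.sqrt target are hypothetical, used only to show the mechanics):

# Minimal sketch of the patch-and-restore pattern used by DisableInitialization.
import math


class Patched:
    def __init__(self, obj, field, func):
        self.obj, self.field, self.func = obj, field, func

    def __enter__(self):
        self.original = getattr(self.obj, self.field)
        setattr(self.obj, self.field, self.func)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # always restore, even if the body raised
        setattr(self.obj, self.field, self.original)


with Patched(math, "sqrt", lambda x: 42):
    assert math.sqrt(9) == 42
assert math.sqrt(9) == 3.0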
modules/sd_hijack.py
ADDED
@@ -0,0 +1,346 @@
import torch
from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None


def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations(option=None):
    global current_optimizer

    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    sgm.modules.diffusionmodules.model.nonlinearity = silu
    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''


def undo_optimizations():
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""

    pass


def weighted_loss(sd_model, pred, target, mean=True):
    # Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    # Return the loss, as mean if specified
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        # Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        # Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    # Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass


class StableDiffusionModelHijack:
    fixes = None
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()

    def hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        if conditioner:
            text_cond_models = []

            for i in range(len(conditioner.embedders)):
                embedder = conditioner.embedders[i]
                typename = type(embedder).__name__
                if typename == 'FrozenOpenCLIPEmbedder':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenCLIPEmbedder':
                    model_embeddings = embedder.transformer.text_model.embeddings
                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

            if len(text_cond_models) == 1:
                m.cond_stage_model = text_cond_models[0]
            else:
                m.cond_stage_model = conditioner

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
|
233 |
+
res = [el]
|
234 |
+
for c in flattened:
|
235 |
+
res += c
|
236 |
+
return res
|
237 |
+
|
238 |
+
self.layers = flatten(m)
|
239 |
+
|
240 |
+
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
|
241 |
+
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
|
242 |
+
|
243 |
+
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
244 |
+
|
245 |
+
def undo_hijack(self, m):
|
246 |
+
if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
247 |
+
m.cond_stage_model = m.cond_stage_model.wrapped
|
248 |
+
|
249 |
+
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
250 |
+
m.cond_stage_model = m.cond_stage_model.wrapped
|
251 |
+
|
252 |
+
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
253 |
+
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
254 |
+
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
255 |
+
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
|
256 |
+
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
257 |
+
m.cond_stage_model = m.cond_stage_model.wrapped
|
258 |
+
|
259 |
+
undo_optimizations()
|
260 |
+
undo_weighted_forward(m)
|
261 |
+
|
262 |
+
self.apply_circular(False)
|
263 |
+
self.layers = None
|
264 |
+
self.clip = None
|
265 |
+
|
266 |
+
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
|
267 |
+
|
268 |
+
def apply_circular(self, enable):
|
269 |
+
if self.circular_enabled == enable:
|
270 |
+
return
|
271 |
+
|
272 |
+
self.circular_enabled = enable
|
273 |
+
|
274 |
+
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
275 |
+
layer.padding_mode = 'circular' if enable else 'zeros'
|
276 |
+
|
277 |
+
def clear_comments(self):
|
278 |
+
self.comments = []
|
279 |
+
self.extra_generation_params = {}
|
280 |
+
|
281 |
+
def get_prompt_lengths(self, text):
|
282 |
+
if self.clip is None:
|
283 |
+
return "-", "-"
|
284 |
+
|
285 |
+
_, token_count = self.clip.process_texts([text])
|
286 |
+
|
287 |
+
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
288 |
+
|
289 |
+
def redo_hijack(self, m):
|
290 |
+
self.undo_hijack(m)
|
291 |
+
self.hijack(m)
|
292 |
+
|
293 |
+
|
294 |
+
class EmbeddingsWithFixes(torch.nn.Module):
|
295 |
+
def __init__(self, wrapped, embeddings):
|
296 |
+
super().__init__()
|
297 |
+
self.wrapped = wrapped
|
298 |
+
self.embeddings = embeddings
|
299 |
+
|
300 |
+
def forward(self, input_ids):
|
301 |
+
batch_fixes = self.embeddings.fixes
|
302 |
+
self.embeddings.fixes = None
|
303 |
+
|
304 |
+
inputs_embeds = self.wrapped(input_ids)
|
305 |
+
|
306 |
+
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
307 |
+
return inputs_embeds
|
308 |
+
|
309 |
+
vecs = []
|
310 |
+
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
311 |
+
for offset, embedding in fixes:
|
312 |
+
emb = devices.cond_cast_unet(embedding.vec)
|
313 |
+
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
314 |
+
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
315 |
+
|
316 |
+
vecs.append(tensor)
|
317 |
+
|
318 |
+
return torch.stack(vecs)
|
319 |
+
|
320 |
+
|
321 |
+
def add_circular_option_to_conv_2d():
|
322 |
+
conv2d_constructor = torch.nn.Conv2d.__init__
|
323 |
+
|
324 |
+
def conv2d_constructor_circular(self, *args, **kwargs):
|
325 |
+
return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
326 |
+
|
327 |
+
torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
328 |
+
|
329 |
+
|
330 |
+
model_hijack = StableDiffusionModelHijack()
|
331 |
+
|
332 |
+
|
333 |
+
def register_buffer(self, name, attr):
|
334 |
+
"""
|
335 |
+
Fix register buffer bug for Mac OS.
|
336 |
+
"""
|
337 |
+
|
338 |
+
if type(attr) == torch.Tensor:
|
339 |
+
if attr.device != devices.device:
|
340 |
+
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
341 |
+
|
342 |
+
setattr(self, name, attr)
|
343 |
+
|
344 |
+
|
345 |
+
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
346 |
+
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
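The weighted_forward/apply_weighted_forward pair above patches a single model instance by binding plain functions as methods with types.MethodType. A minimal illustrative sketch of the same pattern, not from this commit - the Toy class and its squared-error loss are hypothetical stand-ins:

from types import MethodType

import torch


class Toy(torch.nn.Module):
    def get_loss(self, pred, target, mean=True):
        loss = (pred - target) ** 2           # unreduced squared error
        return loss.mean() if mean else loss


def weighted_loss(self, pred, target, mean=True):
    loss = self._old_get_loss(pred, target, mean=False)  # call the saved original
    weight = getattr(self, '_custom_loss_weight', None)
    if weight is not None:
        loss = loss * weight                  # per-element weighting
    return loss.mean() if mean else loss


model = Toy()
model._old_get_loss = model.get_loss                # keep the original around
model.get_loss = MethodType(weighted_loss, model)   # patch this instance only

model._custom_loss_weight = torch.tensor([[1.0], [2.0]])
pred, target = torch.ones(2, 4), torch.zeros(2, 4)
print(model.get_loss(pred, target))  # tensor(1.5000): mean over rows weighted 1 and 2

Because the patch lives on the instance rather than the class, other models of the same class are unaffected, which is why the webui code can apply and undo it per loaded checkpoint.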
modules/sd_hijack_checkpoint.py
ADDED
@@ -0,0 +1,46 @@
from torch.utils.checkpoint import checkpoint

import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel


def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)


def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)


def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)


stored = []


def add():
    if len(stored) != 0:
        return

    stored.extend([
        ldm.modules.attention.BasicTransformerBlock.forward,
        ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
        ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
    ])

    ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward


def remove():
    if len(stored) == 0:
        return

    ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]

    stored.clear()
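add()/remove() above swap three forward methods for checkpointed variants, stashing the originals in `stored` so the patch is reversible. A standalone sketch of what wrapping a forward in torch.utils.checkpoint.checkpoint buys, using a hypothetical Block module:

import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

    def _forward(self, x):
        return self.net(x)

    def forward(self, x):
        # intermediate activations of _forward are dropped after the forward
        # pass and recomputed during backward, trading compute for memory
        return checkpoint(self._forward, x, use_reentrant=False)


block = Block()
x = torch.randn(2, 8, requires_grad=True)
block(x).sum().backward()
print(x.grad.shape)  # torch.Size([2, 8]) - gradients flow through the checkpoint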
modules/sd_hijack_clip.py
ADDED
@@ -0,0 +1,349 @@
import math
from collections import namedtuple

import torch

from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts


class PromptChunk:
    """
    This object contains token ids, weights (multipliers, e.g. 1.4) and textual inversion embedding info for a chunk of prompt.
    If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
    Each PromptChunk contains an exact amount of tokens - 77, which includes one start token and one end token,
    so just 75 tokens from the prompt.
    """

    def __init__(self):
        self.tokens = []
        self.multipliers = []
        self.fixes = []


PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
"""An object of this type is a marker showing that a textual inversion embedding's vectors have to be placed at offset in the prompt
chunk. Those objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""


class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """A pytorch module that is a wrapper for the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
    have unlimited prompt length and assign weights to tokens in the prompt.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()

        self.wrapped = wrapped
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on model."""

        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
        self.chunk_length = 75

        self.is_trainable = getattr(wrapped, 'is_trainable', False)
        self.input_key = getattr(wrapped, 'input_key', 'txt')
        self.legacy_ucg_val = None

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize(self, texts):
        """Converts a batch of texts into a batch of token ids"""

        raise NotImplementedError

    def encode_with_transformers(self, tokens):
        """
        converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
        All python lists with tokens are assumed to have same length, usually 77.
        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
        the model - it can be 768 or 1024.
        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
        """

        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
        transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""

        raise NotImplementedError

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """

        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        last_comma = -1

        def next_chunk(is_last=False):
            """puts current chunk into the list of results and produces the next one - empty;
            if is_last is true, the <end-of-text> tokens at the end won't add to token_count"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length

            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)

                # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                    break_location = last_comma + 1

                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]

                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                emb_len = int(embedding.vec.shape[0])
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens

        if chunk.tokens or not chunks:
            next_chunk(is_last=True)

        return chunks, token_count

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and the maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
        An example shape returned by this function can be: (2, 77, 768).
        For SDXL, instead of returning one tensor as above, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        if opts.use_old_emphasis_implementation:
            import modules.sd_hijack_clip_old
            return modules.sd_hijack_clip_old.forward_old(self, texts)

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for _position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
            hashes = []
            for name, embedding in used_embeddings.items():
                shorthash = embedding.shorthash
                if not shorthash:
                    continue

                name = name.replace(":", "").replace(",", "")
                hashes.append(f"{name}: {shorthash}")

            if hashes:
                self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)

        if getattr(self.wrapped, 'return_pooled', False):
            return torch.hstack(zs), zs[0].pooled
        else:
            return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends one single prompt chunk to be encoded by the transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= (original_mean / new_mean)

        return z


class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        self.id_pad = self.id_end

    def tokenize(self, texts):
        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)

        if opts.CLIP_stop_at_last_layers > 1:
            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
            z = self.wrapped.transformer.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded


class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")

        if self.wrapped.layer == "last":
            z = outputs.last_hidden_state
        else:
            z = outputs.hidden_states[self.wrapped.layer_idx]

        return z
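The emphasis step in process_tokens() above multiplies each token's output vector by its weight and then rescales the whole tensor so its mean matches the pre-emphasis mean (the in-code comment concedes this is a heuristic). A tiny numeric sketch of that rescaling, on made-up values:

import torch

z = torch.tensor([[1.0, 2.0, 3.0, 4.0]])            # stand-in CLIP output, one value per token
multipliers = torch.tensor([[1.0, 1.0, 1.5, 1.0]])  # emphasis weight per token

original_mean = z.mean()                 # 2.5
z = z * multipliers                      # token 2 is boosted; mean drifts to 2.875
z = z * (original_mean / z.mean())       # global rescale restores the mean

print(z.mean())  # tensor(2.5000) - relative emphasis kept, overall magnitude unchanged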
modules/sd_hijack_clip_old.py
ADDED
@@ -0,0 +1,82 @@
from modules import sd_hijack_clip
from modules import shared


def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    id_start = self.id_start
    id_end = self.id_end
    maxlen = self.wrapped.max_length  # you get to stay at 77
    used_custom_terms = []
    remade_batch_tokens = []
    hijack_comments = []
    hijack_fixes = []
    token_count = 0

    cache = {}
    batch_tokens = self.tokenize(texts)
    batch_multipliers = []
    for tokens in batch_tokens:
        tuple_tokens = tuple(tokens)

        if tuple_tokens in cache:
            remade_tokens, fixes, multipliers = cache[tuple_tokens]
        else:
            fixes = []
            remade_tokens = []
            multipliers = []
            mult = 1.0

            i = 0
            while i < len(tokens):
                token = tokens[i]

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)

                mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
                if mult_change is not None:
                    mult *= mult_change
                    i += 1
                elif embedding is None:
                    remade_tokens.append(token)
                    multipliers.append(mult)
                    i += 1
                else:
                    emb_len = int(embedding.vec.shape[0])
                    fixes.append((len(remade_tokens), embedding))
                    remade_tokens += [0] * emb_len
                    multipliers += [mult] * emb_len
                    used_custom_terms.append((embedding.name, embedding.checksum()))
                    i += embedding_length_in_tokens

            if len(remade_tokens) > maxlen - 2:
                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                ovf = remade_tokens[maxlen - 2:]
                overflowing_words = [vocab.get(int(x), "") for x in ovf]
                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

            token_count = len(remade_tokens)
            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

        remade_batch_tokens.append(remade_tokens)
        hijack_fixes.append(fixes)
        batch_multipliers.append(multipliers)
    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count


def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)

    self.hijack.comments += hijack_comments

    if used_custom_terms:
        embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
        self.hijack.comments.append(f"Used embeddings: {embedding_names}")

    self.hijack.fixes = hijack_fixes
    return self.process_tokens(remade_batch_tokens, batch_multipliers)
modules/sd_hijack_inpainting.py
ADDED
@@ -0,0 +1,97 @@
import torch

import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms

from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding


@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)

            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = {}
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])

            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t


def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings

    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
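The old_eps branches above implement a pseudo linear multistep (Adams-Bashforth) corrector: the epsilon actually used is a fixed linear combination of the current prediction and up to three stored ones. The combination rule in isolation, as a sketch:

def plms_eps_prime(e_t, old_eps):
    """Adams-Bashforth style combination used by PLMS, chosen by history length.
    Assumes at least one stored eps; the first step uses the Euler corrector instead."""
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24


print(plms_eps_prime(1.0, [0.5]))  # 1.25 - works on floats or tensors alike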
modules/sd_hijack_ip2p.py
ADDED
@@ -0,0 +1,10 @@
import os.path


def should_hijack_ip2p(checkpoint_info):
    from modules import sd_models_config

    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()

    return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
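should_hijack_ip2p() is a pure filename heuristic: hijack when the checkpoint name mentions pix2pix but its matched config does not. The same check, inlined with hypothetical paths:

import os.path

ckpt_path = "/models/instruct-pix2pix-00-22000.safetensors"  # hypothetical
cfg_path = "/configs/v1-inference.yaml"                      # hypothetical

ckpt_basename = os.path.basename(ckpt_path).lower()
cfg_basename = os.path.basename(cfg_path).lower()

print("pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename)  # True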
modules/sd_hijack_open_clip.py
ADDED
@@ -0,0 +1,71 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts

tokenizer = open_clip.tokenizer._tokenizer


class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
        self.id_start = tokenizer.encoder["<start_of_text>"]
        self.id_end = tokenizer.encoder["<end_of_text>"]
        self.id_pad = 0

    def tokenize(self, texts):
        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'

        tokenized = [tokenizer.encode(text) for text in texts]

        return tokenized

    def encode_with_transformers(self, tokens):
        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
        z = self.wrapped.encode_with_transformer(tokens)

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)

        return embedded


class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
        self.id_start = tokenizer.encoder["<start_of_text>"]
        self.id_end = tokenizer.encoder["<end_of_text>"]
        self.id_pad = 0

    def tokenize(self, texts):
        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'

        tokenized = [tokenizer.encode(text) for text in texts]

        return tokenized

    def encode_with_transformers(self, tokens):
        d = self.wrapped.encode_with_transformer(tokens)
        z = d[self.wrapped.layer]

        pooled = d.get("pooled")
        if pooled is not None:
            z.pooled = pooled

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded
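Both wrappers above tokenize with open_clip's raw encode() (which adds no special tokens) and rely on the base class to wrap chunks in the <start_of_text>/<end_of_text> ids looked up here. A small sketch of that lookup, assuming the open_clip package is installed:

import open_clip.tokenizer

tokenizer = open_clip.tokenizer._tokenizer

id_start = tokenizer.encoder["<start_of_text>"]
id_end = tokenizer.encoder["<end_of_text>"]

tokens = tokenizer.encode("a photo of a cat")  # plain ids, no BOS/EOS
print([id_start] + tokens + [id_end])          # what one wrapped chunk looks like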
modules/sd_hijack_optimizations.py
ADDED
@@ -0,0 +1,668 @@
from __future__ import annotations
import math
import psutil

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

import sgm.modules.attention
import sgm.modules.diffusionmodules.model

diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
    name: str = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward


class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])
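list_optimizers() only registers candidates; the actual pick (in sd_hijack.apply_optimizations, whose tail is shown at the top of this diff) filters on is_available() and prefers the highest priority, falling back to optimizers[0]. A self-contained sketch of that selection rule with hypothetical stub objects:

class StubOpt:
    def __init__(self, name, priority, available=True):
        self.name, self.priority, self.available = name, priority, available

    def is_available(self):
        return self.available


candidates = [
    StubOpt("xformers", 100, available=False),  # e.g. xformers not installed
    StubOpt("Doggettx", 90),
    StubOpt("sub-quadratic", 10),
]

optimizers = sorted(
    (o for o in candidates if o.is_available()),
    key=lambda o: o.priority,
    reverse=True,
)
print(optimizers[0].name)  # Doggettx - best available option wins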
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)


def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)


def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r


def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)


def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))


def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))


def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)


def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x


def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens
|
446 |
+
with devices.without_autocast(disable=q.dtype == v.dtype):
|
447 |
+
return sub_quadratic_attention.efficient_dot_product_attention(
|
448 |
+
q,
|
449 |
+
k,
|
450 |
+
v,
|
451 |
+
query_chunk_size=q_chunk_size,
|
452 |
+
kv_chunk_size=kv_chunk_size,
|
453 |
+
kv_chunk_size_min = kv_chunk_size_min,
|
454 |
+
use_checkpoint=use_checkpoint,
|
455 |
+
)
|
456 |
+
|
457 |
+
|
458 |
+
def get_xformers_flash_attention_op(q, k, v):
|
459 |
+
if not shared.cmd_opts.xformers_flash_attention:
|
460 |
+
return None
|
461 |
+
|
462 |
+
try:
|
463 |
+
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
464 |
+
fw, bw = flash_attention_op
|
465 |
+
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
|
466 |
+
return flash_attention_op
|
467 |
+
except Exception as e:
|
468 |
+
errors.display_once(e, "enabling flash attention")
|
469 |
+
|
470 |
+
return None
|
471 |
+
|
472 |
+
|
473 |
+
def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
|
474 |
+
h = self.heads
|
475 |
+
q_in = self.to_q(x)
|
476 |
+
context = default(context, x)
|
477 |
+
|
478 |
+
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
479 |
+
k_in = self.to_k(context_k)
|
480 |
+
v_in = self.to_v(context_v)
|
481 |
+
|
482 |
+
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
|
483 |
+
del q_in, k_in, v_in
|
484 |
+
|
485 |
+
dtype = q.dtype
|
486 |
+
if shared.opts.upcast_attn:
|
487 |
+
q, k, v = q.float(), k.float(), v.float()
|
488 |
+
|
489 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
490 |
+
|
491 |
+
out = out.to(dtype)
|
492 |
+
|
493 |
+
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
494 |
+
return self.to_out(out)
|
495 |
+
|
496 |
+
|
497 |
+
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
498 |
+
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
499 |
+
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
|
500 |
+
batch_size, sequence_length, inner_dim = x.shape
|
501 |
+
|
502 |
+
if mask is not None:
|
503 |
+
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
|
504 |
+
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
|
505 |
+
|
506 |
+
h = self.heads
|
507 |
+
q_in = self.to_q(x)
|
508 |
+
context = default(context, x)
|
509 |
+
|
510 |
+
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
511 |
+
k_in = self.to_k(context_k)
|
512 |
+
v_in = self.to_v(context_v)
|
513 |
+
|
514 |
+
head_dim = inner_dim // h
|
515 |
+
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
516 |
+
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
517 |
+
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
518 |
+
|
519 |
+
del q_in, k_in, v_in
|
520 |
+
|
521 |
+
dtype = q.dtype
|
522 |
+
if shared.opts.upcast_attn:
|
523 |
+
q, k, v = q.float(), k.float(), v.float()
|
524 |
+
|
525 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
526 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
527 |
+
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
528 |
+
)
|
529 |
+
|
530 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
|
531 |
+
hidden_states = hidden_states.to(dtype)
|
532 |
+
|
533 |
+
# linear proj
|
534 |
+
hidden_states = self.to_out[0](hidden_states)
|
535 |
+
# dropout
|
536 |
+
hidden_states = self.to_out[1](hidden_states)
|
537 |
+
return hidden_states
|
538 |
+
|
539 |
+
|
540 |
+
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
|
541 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
542 |
+
return scaled_dot_product_attention_forward(self, x, context, mask)
|
543 |
+
|
544 |
+
|
545 |
+
def cross_attention_attnblock_forward(self, x):
|
546 |
+
h_ = x
|
547 |
+
h_ = self.norm(h_)
|
548 |
+
q1 = self.q(h_)
|
549 |
+
k1 = self.k(h_)
|
550 |
+
v = self.v(h_)
|
551 |
+
|
552 |
+
# compute attention
|
553 |
+
b, c, h, w = q1.shape
|
554 |
+
|
555 |
+
q2 = q1.reshape(b, c, h*w)
|
556 |
+
del q1
|
557 |
+
|
558 |
+
q = q2.permute(0, 2, 1) # b,hw,c
|
559 |
+
del q2
|
560 |
+
|
561 |
+
k = k1.reshape(b, c, h*w) # b,c,hw
|
562 |
+
del k1
|
563 |
+
|
564 |
+
h_ = torch.zeros_like(k, device=q.device)
|
565 |
+
|
566 |
+
mem_free_total = get_available_vram()
|
567 |
+
|
568 |
+
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
569 |
+
mem_required = tensor_size * 2.5
|
570 |
+
steps = 1
|
571 |
+
|
572 |
+
if mem_required > mem_free_total:
|
573 |
+
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
574 |
+
|
575 |
+
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
576 |
+
for i in range(0, q.shape[1], slice_size):
|
577 |
+
end = i + slice_size
|
578 |
+
|
579 |
+
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
580 |
+
w2 = w1 * (int(c)**(-0.5))
|
581 |
+
del w1
|
582 |
+
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
583 |
+
del w2
|
584 |
+
|
585 |
+
# attend to values
|
586 |
+
v1 = v.reshape(b, c, h*w)
|
587 |
+
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
588 |
+
del w3
|
589 |
+
|
590 |
+
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
591 |
+
del v1, w4
|
592 |
+
|
593 |
+
h2 = h_.reshape(b, c, h, w)
|
594 |
+
del h_
|
595 |
+
|
596 |
+
h3 = self.proj_out(h2)
|
597 |
+
del h2
|
598 |
+
|
599 |
+
h3 += x
|
600 |
+
|
601 |
+
return h3
|
602 |
+
|
603 |
+
|
604 |
+
def xformers_attnblock_forward(self, x):
|
605 |
+
try:
|
606 |
+
h_ = x
|
607 |
+
h_ = self.norm(h_)
|
608 |
+
q = self.q(h_)
|
609 |
+
k = self.k(h_)
|
610 |
+
v = self.v(h_)
|
611 |
+
b, c, h, w = q.shape
|
612 |
+
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
613 |
+
dtype = q.dtype
|
614 |
+
if shared.opts.upcast_attn:
|
615 |
+
q, k = q.float(), k.float()
|
616 |
+
q = q.contiguous()
|
617 |
+
k = k.contiguous()
|
618 |
+
v = v.contiguous()
|
619 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
|
620 |
+
out = out.to(dtype)
|
621 |
+
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
622 |
+
out = self.proj_out(out)
|
623 |
+
return x + out
|
624 |
+
except NotImplementedError:
|
625 |
+
return cross_attention_attnblock_forward(self, x)
|
626 |
+
|
627 |
+
|
628 |
+
def sdp_attnblock_forward(self, x):
|
629 |
+
h_ = x
|
630 |
+
h_ = self.norm(h_)
|
631 |
+
q = self.q(h_)
|
632 |
+
k = self.k(h_)
|
633 |
+
v = self.v(h_)
|
634 |
+
b, c, h, w = q.shape
|
635 |
+
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
636 |
+
dtype = q.dtype
|
637 |
+
if shared.opts.upcast_attn:
|
638 |
+
q, k, v = q.float(), k.float(), v.float()
|
639 |
+
q = q.contiguous()
|
640 |
+
k = k.contiguous()
|
641 |
+
v = v.contiguous()
|
642 |
+
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
|
643 |
+
out = out.to(dtype)
|
644 |
+
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
645 |
+
out = self.proj_out(out)
|
646 |
+
return x + out
|
647 |
+
|
648 |
+
|
649 |
+
def sdp_no_mem_attnblock_forward(self, x):
|
650 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
651 |
+
return sdp_attnblock_forward(self, x)
|
652 |
+
|
653 |
+
|
654 |
+
def sub_quad_attnblock_forward(self, x):
|
655 |
+
h_ = x
|
656 |
+
h_ = self.norm(h_)
|
657 |
+
q = self.q(h_)
|
658 |
+
k = self.k(h_)
|
659 |
+
v = self.v(h_)
|
660 |
+
b, c, h, w = q.shape
|
661 |
+
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
662 |
+
q = q.contiguous()
|
663 |
+
k = k.contiguous()
|
664 |
+
v = v.contiguous()
|
665 |
+
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
666 |
+
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
667 |
+
out = self.proj_out(out)
|
668 |
+
return x + out
|
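The slicing helpers above are exact rather than approximate: softmax over the key axis is computed independently for each query row, so chunking the query dimension cannot change the result. A standalone sketch checking this, not part of the module; attention_ref is a hypothetical stand-in for einsum_op_compvis:

# Illustrative check (assumed names): query-axis slicing reproduces full attention.
import torch

def attention_ref(q, k, v):
    s = torch.einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return torch.einsum('b i j, b j d -> b i d', s, v)

def attention_sliced(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        # each query slice attends to the full key/value tensors
        r[:, i:end] = attention_ref(q[:, i:end], k, v)
    return r

q, k, v = (torch.randn(2, 64, 40) for _ in range(3))
assert torch.allclose(attention_ref(q, k, v), attention_sliced(q, k, v, slice_size=16), atol=1e-5)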
modules/sd_hijack_unet.py
ADDED
@@ -0,0 +1,85 @@
import torch
from packaging import version

from modules import devices
from modules.sd_hijack_utils import CondFunc


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()


# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):

    if isinstance(cond, dict):
        for y in cond.keys():
            if isinstance(cond[y], list):
                cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
            else:
                cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]

    with devices.autocast():
        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()


class GELUHijack(torch.nn.GELU, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)
    def forward(self, x):
        if devices.unet_needs_upcast:
            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
        else:
            return torch.nn.GELU.forward(self, x)


ddpm_edit_hijack = None
def hijack_ddpm_edit():
    global ddpm_edit_hijack
    if not ddpm_edit_hijack:
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)


unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
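For illustration, a minimal standalone sketch of what the cat override above does when two feature maps disagree in their trailing spatial dimensions (as happens with image sizes that are multiples of 8 but not of 64); the tensor shapes are made up:

# Illustrative only: resize the first tensor to match the second before concatenation.
import torch

a = torch.randn(1, 4, 17, 17)
b = torch.randn(1, 4, 16, 16)

if a.shape[-2:] != b.shape[-2:]:
    a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

out = torch.cat((a, b), dim=1)
print(out.shape)  # torch.Size([1, 8, 16, 16])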
modules/sd_hijack_utils.py
ADDED
@@ -0,0 +1,28 @@
import importlib

class CondFunc:
    def __new__(cls, orig_func, sub_func, cond_func):
        self = super(CondFunc, cls).__new__(cls)
        if isinstance(orig_func, str):
            func_path = orig_func.split('.')
            for i in range(len(func_path)-1, -1, -1):
                try:
                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                    break
                except ImportError:
                    pass
            for attr_name in func_path[i:-1]:
                resolved_obj = getattr(resolved_obj, attr_name)
            orig_func = getattr(resolved_obj, func_path[-1])
            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: self(*args, **kwargs)
    def __init__(self, orig_func, sub_func, cond_func):
        self.__orig_func = orig_func
        self.__sub_func = sub_func
        self.__cond_func = cond_func
    def __call__(self, *args, **kwargs):
        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
            return self.__sub_func(self.__orig_func, *args, **kwargs)
        else:
            return self.__orig_func(*args, **kwargs)
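A hedged usage sketch for CondFunc: the demo function below is invented for illustration. When the first argument is not a dotted import-path string, the import-resolution branch is skipped and the wrapper is simply returned, which keeps the example self-contained; in the webui the first argument is usually a dotted path, as in sd_hijack_unet.py above.

# Illustrative only: run a substitute function when a condition holds.
from modules.sd_hijack_utils import CondFunc  # assumes the webui package layout

def demo(x):
    return x + 1

demo = CondFunc(
    demo,
    lambda orig_func, x: orig_func(x) * 10,  # substitute: scale the original result
    lambda orig_func, x: x > 0,              # condition: only for positive inputs
)

print(demo(2))   # 30 -> substitute path: (2 + 1) * 10
print(demo(-2))  # -1 -> condition false: original function runs unchanged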
modules/sd_hijack_xlmr.py
ADDED
@@ -0,0 +1,32 @@
import torch

from modules import sd_hijack_clip, devices


class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.id_start = wrapped.config.bos_token_id
        self.id_end = wrapped.config.eos_token_id
        self.id_pad = wrapped.config.pad_token_id

        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma

    def encode_with_transformers(self, tokens):
        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
        # layer to work with - you have to use the last

        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
        z = features['projection_state']

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.roberta.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)

        return embedded
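A small illustration of the attention-mask construction in encode_with_transformers above; the token ids below are fabricated, with id_pad = 1 standing in for wrapped.config.pad_token_id:

# Illustrative only: pad positions get mask 0, real tokens get mask 1.
import torch

id_pad = 1
tokens = torch.tensor([[0, 42, 7, 2, 1, 1]])  # bos, two word tokens, eos, two pads
attention_mask = (tokens != id_pad).to(device=tokens.device, dtype=torch.int64)
print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])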
modules/sd_models.py
ADDED
@@ -0,0 +1,643 @@
import collections
import os.path
import sys
import gc
import threading

import torch
import re
import safetensors.torch
from omegaconf import OmegaConf
from os import mkdir
from urllib import request
import ldm.modules.midas as midas

from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

checkpoints_list = {}
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()


class CheckpointInfo:
    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)

        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(filename)

        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

        self.metadata = {}

        _, ext = os.path.splitext(self.filename)
        if ext.lower() == ".safetensors":
            try:
                self.metadata = read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading checkpoint metadata: {filename}")

    def register(self):
        checkpoints_list[self.title] = self
        for id in self.ids:
            checkpoint_aliases[id] = self

    def calculate_shorthash(self):
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        self.shorthash = self.sha256[0:10]

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

        checkpoints_list.pop(self.title)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()

        return self.shorthash


try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    pass


def setup_model():
    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()


def checkpoint_tiles():
    def convert(name):
        return int(name) if name.isdigit() else name.lower()

    def alphanumeric_key(key):
        return [convert(c) for c in re.split('([0-9]+)', key)]

    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)


def list_models():
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
        model_url = None
    else:
        model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])

    if os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in sorted(model_list, key=str.lower):
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()


def get_closet_checkpoint_match(search_string):
    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
        return checkpoint_info

    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
    if found:
        return found[0]

    return None


def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""

    try:
        with open(filename, "rb") as file:
            import hashlib
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'


def select_checkpoint():
    """Raises `FileNotFoundError` if no checkpoints are found."""
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        error_message += "\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info


checkpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    for text, replacement in checkpoint_dict_replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]

    return k


def get_state_dict_from_checkpoint(pl_sd):
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    sd = {}
    for k, v in pl_sd.items():
        new_key = transform_checkpoint_dict_key(k)

        if new_key is not None:
            sd[new_key] = v

    pl_sd.clear()
    pl_sd.update(sd)

    return pl_sd


def read_metadata_from_safetensors(filename):
    import json

    with open(filename, mode="rb") as file:
        metadata_len = file.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = file.read(2)

        assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
        json_data = json_start + file.read(metadata_len-2)
        json_obj = json.loads(json_data)

        res = {}
        for k, v in json_obj.get("__metadata__", {}).items():
            res[k] = v
            if isinstance(v, str) and v[0:1] == '{':
                try:
                    res[k] = json.loads(v)
                except Exception:
                    pass

        return res


def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd


def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info in checkpoints_loaded:
        # use checkpoint cache
        print(f"Loading weights [{sd_model_hash}] from cache")
        return checkpoints_loaded[checkpoint_info]

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
    res = read_state_dict(checkpoint_info.filename)
    timer.record("load weights from disk")

    return res


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    model.is_sdxl = hasattr(model, 'conditioner')
    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
    model.is_sd1 = not model.is_sdxl and not model.is_sd2

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if not shared.cmd_opts.no_half:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        model.half()
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        timer.record("apply half()")

    devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")


def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper


def repair_config(sd_config):

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    if hasattr(sd_config.model.params, 'unet_config'):
        if shared.cmd_opts.no_half:
            sd_config.model.params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling:
            sd_config.model.params.unet_config.params.use_fp16 = True

    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)


sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'


class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        if self.was_loaded_at_least_once:
            return self.sd_model

        if self.sd_model is None:
            with self.lock:
                if self.sd_model is not None or self.was_loaded_at_least_once:
                    return self.sd_model

                try:
                    load_model()
                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
                    print("Stable diffusion model failed to load", file=sys.stderr)
                    self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v):
        self.sd_model = v


model_data = SdModelData()


def get_empty_cond(sd_model):
    if hasattr(sd_model, 'conditioner'):
        d = sd_model.get_learned_conditioning([""])
        return d['crossattn']
    else:
        return sd_model.cond_stage_model([""])


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    if model_data.sd_model:
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        gc.collect()
        devices.torch_gc()

    do_inpainting_hijack()

    timer = Timer()

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
            sd_model = instantiate_from_config(sd_config.model)
    except Exception:
        pass

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
        sd_model = instantiate_from_config(sd_config.model)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    load_model_weights(sd_model, checkpoint_info, state_dict, timer)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.sd_model = sd_model
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model


def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return

    sd_unet.apply_unet("None")

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

    print(f"Weights loaded in {timer.summary()}.")

    return sd_model


def unload_model_weights(sd_model=None, info=None):
    from modules import devices, sd_hijack
    timer = Timer()

    if model_data.sd_model:
        model_data.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()

    print(f"Unloaded weights {timer.summary()}.")

    return sd_model


def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.
    """

    current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)

    if current_token_merging_ratio == token_merging_ratio:
        return

    if current_token_merging_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        tomesd.apply_patch(
            sd_model,
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False
        )

    sd_model.applied_token_merged_ratio = token_merging_ratio
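The parsing in read_metadata_from_safetensors relies on the safetensors container layout: an 8-byte little-endian header length followed by that many bytes of JSON, whose optional "__metadata__" object carries string key/value pairs. A hedged sketch with a hand-built in-memory stand-in for a real checkpoint file:

# Illustrative only: the bytes below fake the start of a .safetensors file.
import io
import json

header = json.dumps({"__metadata__": {"ss_base_model": "sd-v1-5"}}).encode("utf-8")
blob = len(header).to_bytes(8, "little") + header  # a real file continues with tensor data

file = io.BytesIO(blob)
metadata_len = int.from_bytes(file.read(8), "little")
json_obj = json.loads(file.read(metadata_len))
print(json_obj["__metadata__"])  # {'ss_base_model': 'sd-v1-5'}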
modules/sd_models_config.py
ADDED
@@ -0,0 +1,125 @@
import os

import torch

from modules import shared, paths, sd_disable_initialization

sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")


config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")


def is_using_v_parameterization_for_sd2(state_dict):
    """
    Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
    """

    import ldm.modules.diffusionmodules.openaimodel
    from modules import devices

    device = devices.cpu

    with sd_disable_initialization.DisableInitialization():
        unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
            use_checkpoint=True,
            use_fp16=False,
            image_size=32,
            in_channels=4,
            out_channels=4,
            model_channels=320,
            attention_resolutions=[4, 2, 1],
            num_res_blocks=2,
            channel_mult=[1, 2, 4, 4],
            num_head_channels=64,
            use_spatial_transformer=True,
            use_linear_in_transformer=True,
            transformer_depth=1,
            context_dim=1024,
            legacy=False
        )
        unet.eval()

    with torch.no_grad():
        unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
        unet.load_state_dict(unet_sd, strict=True)
        unet.to(device=device, dtype=torch.float)

        test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
        x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5

        out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()

    return out < -1


def guess_model_config_from_state_dict(sd, filename):
    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
        return config_sdxl
    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
        return config_sdxl_refiner
    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
        return config_depth_model
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
        return config_unclip
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
        return config_unopenclip

    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
        if diffusion_model_input.shape[1] == 9:
            return config_sd2_inpainting
        elif is_using_v_parameterization_for_sd2(sd):
            return config_sd2v
        else:
            return config_sd2

    if diffusion_model_input is not None:
        if diffusion_model_input.shape[1] == 9:
            return config_inpainting
        if diffusion_model_input.shape[1] == 8:
            return config_instruct_pix2pix

    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
        return config_alt_diffusion

    return config_default


def find_checkpoint_config(state_dict, info):
    if info is None:
        return guess_model_config_from_state_dict(state_dict, "")

    config = find_checkpoint_config_near_filename(info)
    if config is not None:
        return config

    return guess_model_config_from_state_dict(state_dict, info.filename)


def find_checkpoint_config_near_filename(info):
    if info is None:
        return None

    config = f"{os.path.splitext(info.filename)[0]}.yaml"
    if os.path.exists(config):
        return config

    return None
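The shape checks in guess_model_config_from_state_dict hinge on the input-channel count of the UNet's first convolution, whose weight has shape [out_ch, in_ch, 3, 3]. A condensed, hypothetical sketch of that branch (guess_family and the one-key state dict are fabricated for illustration):

# Illustrative only: distinguish model families by the first conv's input channels.
import torch

def guess_family(sd):
    w = sd.get('model.diffusion_model.input_blocks.0.0.weight')
    if w is None:
        return "unknown"
    in_ch = w.shape[1]
    if in_ch == 9:
        return "inpainting"        # 4 latent + 4 masked-image latent + 1 mask channel
    if in_ch == 8:
        return "instruct-pix2pix"  # 4 latent + 4 conditioning-image latent channels
    return "text-to-image"         # plain 4-channel latent input

sd = {'model.diffusion_model.input_blocks.0.0.weight': torch.zeros(320, 9, 3, 3)}
print(guess_family(sd))  # inpainting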
modules/sd_models_xl.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import sgm.models.diffusion
|
6 |
+
import sgm.modules.diffusionmodules.denoiser_scaling
|
7 |
+
import sgm.modules.diffusionmodules.discretizer
|
8 |
+
from modules import devices, shared, prompt_parser
|
9 |
+
|
10 |
+
|
11 |
+
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
12 |
+
for embedder in self.conditioner.embedders:
|
13 |
+
embedder.ucg_rate = 0.0
|
14 |
+
|
15 |
+
width = getattr(batch, 'width', 1024)
|
16 |
+
height = getattr(batch, 'height', 1024)
|
17 |
+
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
18 |
+
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
19 |
+
|
20 |
+
devices_args = dict(device=devices.device, dtype=devices.dtype)
|
21 |
+
|
22 |
+
sdxl_conds = {
|
23 |
+
"txt": batch,
|
24 |
+
"original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
25 |
+
"crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
26 |
+
"target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
27 |
+
"aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
28 |
+
}
|
29 |
+
|
30 |
+
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
31 |
+
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
32 |
+
|
33 |
+
return c
|
34 |
+
|
35 |
+
|
36 |
+
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
|
37 |
+
return self.model(x, t, cond)
|
38 |
+
|
39 |
+
|
40 |
+
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
45 |
+
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
46 |
+
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
47 |
+
|
48 |
+
|
49 |
+
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
50 |
+
res = []
|
51 |
+
|
52 |
+
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
53 |
+
encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
54 |
+
res.append(encoded)
|
55 |
+
|
56 |
+
return torch.cat(res, dim=1)
|
57 |
+
|
58 |
+
|
59 |
+
def process_texts(self, texts):
|
60 |
+
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
61 |
+
return embedder.process_texts(texts)
|
62 |
+
|
63 |
+
|
64 |
+
def get_target_prompt_token_count(self, token_count):
|
65 |
+
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
66 |
+
return embedder.get_target_prompt_token_count(token_count)
|
67 |
+
|
68 |
+
|
69 |
+
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
70 |
+
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
71 |
+
sgm.modules.GeneralConditioner.process_texts = process_texts
|
72 |
+
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
73 |
+
|
74 |
+
|
75 |
+
def extend_sdxl(model):
|
76 |
+
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
77 |
+
|
78 |
+
dtype = next(model.model.diffusion_model.parameters()).dtype
|
79 |
+
model.model.diffusion_model.dtype = dtype
|
80 |
+
model.model.conditioning_key = 'crossattn'
|
81 |
+
model.cond_stage_key = 'txt'
|
82 |
+
# model.cond_stage_model will be set in sd_hijack
|
83 |
+
|
84 |
+
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
85 |
+
|
86 |
+
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
87 |
+
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
88 |
+
|
89 |
+
model.conditioner.wrapped = torch.nn.Module()
|
90 |
+
|
91 |
+
|
92 |
+
sgm.modules.attention.print = lambda *args: None
|
93 |
+
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
94 |
+
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
95 |
+
sgm.modules.encoders.modules.print = lambda *args: None
|
96 |
+
|
97 |
+
# this gets the code to load the vanilla attention that we override
|
98 |
+
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
99 |
+
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
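The module-level assignments above are plain attribute monkey-patching: free functions become bound methods of the sgm classes the moment they are assigned. A self-contained sketch of the same mechanism, using a stand-in class since sgm is not importable outside the webui:

class Engine:
    """Stand-in for sgm.models.diffusion.DiffusionEngine."""

def apply_model(self, x, t, cond):
    # the free function becomes an ordinary bound method once assigned below
    return self.model(x, t, cond)

Engine.apply_model = apply_model  # same mechanism as the sgm assignments above

engine = Engine()
engine.model = lambda x, t, cond: f"denoised({x}, {t}, {cond})"
print(engine.apply_model("x", 0, "c"))  # -> denoised(x, 0, c)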
modules/sd_samplers.py
ADDED
@@ -0,0 +1,56 @@
+from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
+
+# imports for functions that previously were here and are used by other modules
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image  # noqa: F401
+
+all_samplers = [
+    *sd_samplers_kdiffusion.samplers_data_k_diffusion,
+    *sd_samplers_compvis.samplers_data_compvis,
+]
+all_samplers_map = {x.name: x for x in all_samplers}
+
+samplers = []
+samplers_for_img2img = []
+samplers_map = {}
+
+
+def find_sampler_config(name):
+    if name is not None:
+        config = all_samplers_map.get(name, None)
+    else:
+        config = all_samplers[0]
+
+    return config
+
+
+def create_sampler(name, model):
+    config = find_sampler_config(name)
+
+    assert config is not None, f'bad sampler name: {name}'
+
+    if model.is_sdxl and config.options.get("no_sdxl", False):
+        raise Exception(f"Sampler {config.name} is not supported for SDXL")
+
+    sampler = config.constructor(model)
+    sampler.config = config
+
+    return sampler
+
+
+def set_samplers():
+    global samplers, samplers_for_img2img
+
+    hidden = set(shared.opts.hide_samplers)
+    hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
+
+    samplers = [x for x in all_samplers if x.name not in hidden]
+    samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+
+    samplers_map.clear()
+    for sampler in all_samplers:
+        samplers_map[sampler.name.lower()] = sampler.name
+        for alias in sampler.aliases:
+            samplers_map[alias.lower()] = sampler.name
+
+
+set_samplers()
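The lowercased name/alias map built by set_samplers() is what lets generation parameters reference samplers case-insensitively. A standalone sketch of the same mapping with two hypothetical entries:

from collections import namedtuple

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

demo_samplers = [
    SamplerData('Euler a', None, ['k_euler_a', 'k_euler_ancestral'], {}),
    SamplerData('DDIM', None, [], {}),
]

samplers_map = {}
for sampler in demo_samplers:
    samplers_map[sampler.name.lower()] = sampler.name  # canonical name under lowercased key
    for alias in sampler.aliases:
        samplers_map[alias.lower()] = sampler.name     # aliases resolve to the same name

print(samplers_map['k_euler_a'])  # -> 'Euler a'
print(samplers_map['ddim'])       # -> 'DDIM'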
modules/sd_samplers_common.py
ADDED
@@ -0,0 +1,95 @@
+from collections import namedtuple
+import numpy as np
+import torch
+from PIL import Image
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
+
+from modules.shared import opts, state
+import modules.shared as shared
+
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+def setup_img2img_steps(p, steps=None):
+    if opts.img2img_fix_steps or steps is not None:
+        requested_steps = (steps or p.steps)
+        steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+        t_enc = requested_steps - 1
+    else:
+        steps = p.steps
+        t_enc = int(min(p.denoising_strength, 0.999) * steps)
+
+    return steps, t_enc
+
+
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
+
+
+def single_sample_to_image(sample, approximation=None):
+    if approximation is None:
+        approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+    if approximation == 2:
+        x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
+    elif approximation == 1:
+        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+    elif approximation == 3:
+        x_sample = sample * 1.5
+        x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+    else:
+        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+
+    x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
+    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+    x_sample = x_sample.astype(np.uint8)
+
+    return Image.fromarray(x_sample)
+
+
+def sample_to_image(samples, index=0, approximation=None):
+    return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+    return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
+def store_latent(decoded):
+    state.current_latent = decoded
+
+    if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+        if not shared.parallel_processing_allowed:
+            shared.state.assign_current_image(sample_to_image(decoded))
+
+
+def is_sampler_using_eta_noise_seed_delta(p):
+    """returns whether sampler from config will use eta noise seed delta for image creation"""
+
+    sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+
+    eta = p.eta
+
+    if eta is None and p.sampler is not None:
+        eta = p.sampler.eta
+
+    if eta is None and sampler_config is not None:
+        eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
+
+    if eta == 0:
+        return False
+
+    return sampler_config.options.get("uses_ensd", False)
+
+
+class InterruptedException(BaseException):
+    pass
+
+
+if opts.randn_source == "CPU":
+    import torchsde._brownian.brownian_interval
+
+    def torchsde_randn(size, dtype, device, seed):
+        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+
+    torchsde._brownian.brownian_interval._randn = torchsde_randn
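The step arithmetic in setup_img2img_steps is easiest to see with concrete numbers. A standalone rework of the same math, with opts.img2img_fix_steps passed in explicitly and a SimpleNamespace standing in for the processing object:

from types import SimpleNamespace

def img2img_steps(p, steps=None, fix_steps=False):
    # same arithmetic as setup_img2img_steps, with opts.img2img_fix_steps
    # passed explicitly so the sketch has no dependency on shared options
    if fix_steps or steps is not None:
        requested_steps = (steps or p.steps)
        steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
        t_enc = requested_steps - 1
    else:
        steps = p.steps
        t_enc = int(min(p.denoising_strength, 0.999) * steps)
    return steps, t_enc

p = SimpleNamespace(steps=20, denoising_strength=0.5)
print(img2img_steps(p))                  # (20, 10): only 10 of the 20 steps actually run
print(img2img_steps(p, fix_steps=True))  # (40, 19): schedule stretched so ~20 steps run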
modules/sd_samplers_compvis.py
ADDED
@@ -0,0 +1,224 @@
+import math
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+import numpy as np
+import torch
+
+from modules.shared import state
+from modules import sd_samplers_common, prompt_parser, shared
+import modules.models.diffusion.uni_pc
+
+
+samplers_data_compvis = [
+    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
+    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
+    sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
+]
+
+
+class VanillaStableDiffusionSampler:
+    def __init__(self, constructor, sd_model):
+        self.sampler = constructor(sd_model)
+        self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
+        self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+        self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
+        self.orig_p_sample_ddim = None
+        if self.is_plms:
+            self.orig_p_sample_ddim = self.sampler.p_sample_plms
+        elif self.is_ddim:
+            self.orig_p_sample_ddim = self.sampler.p_sample_ddim
+        self.mask = None
+        self.nmask = None
+        self.init_latent = None
+        self.sampler_noises = None
+        self.step = 0
+        self.stop_at = None
+        self.eta = None
+        self.config = None
+        self.last_latent = None
+
+        self.conditioning_key = sd_model.model.conditioning_key
+
+    def number_of_needed_noises(self, p):
+        return 0
+
+    def launch_sampling(self, steps, func):
+        state.sampling_steps = steps
+        state.sampling_step = 0
+
+        try:
+            return func()
+        except sd_samplers_common.InterruptedException:
+            return self.last_latent
+
+    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+        x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
+
+        res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
+
+        x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
+
+        return res
+
+    def before_sample(self, x, ts, cond, unconditional_conditioning):
+        if state.interrupted or state.skipped:
+            raise sd_samplers_common.InterruptedException
+
+        if self.stop_at is not None and self.step > self.stop_at:
+            raise sd_samplers_common.InterruptedException
+
+        # Have to unwrap the inpainting conditioning here to perform pre-processing
+        image_conditioning = None
+        uc_image_conditioning = None
+        if isinstance(cond, dict):
+            if self.conditioning_key == "crossattn-adm":
+                image_conditioning = cond["c_adm"]
+                uc_image_conditioning = unconditional_conditioning["c_adm"]
+            else:
+                image_conditioning = cond["c_concat"][0]
+            cond = cond["c_crossattn"][0]
+            unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
+
+        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+        assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
+        cond = tensor
+
+        # for DDIM, shapes must match, we can't just process cond and uncond independently;
+        # filling unconditional_conditioning with repeats of the last vector to match length is
+        # not 100% correct but should work well enough
+        if unconditional_conditioning.shape[1] < cond.shape[1]:
+            last_vector = unconditional_conditioning[:, -1:]
+            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
+            unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
+        elif unconditional_conditioning.shape[1] > cond.shape[1]:
+            unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
+
+        if self.mask is not None:
+            img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+            x = img_orig * self.mask + self.nmask * x
+
+        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+        # Note that they need to be lists because it just concatenates them later.
+        if image_conditioning is not None:
+            if self.conditioning_key == "crossattn-adm":
+                cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
+                unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
+            else:
+                cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+        return x, ts, cond, unconditional_conditioning
+
+    def update_step(self, last_latent):
+        if self.mask is not None:
+            self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
+        else:
+            self.last_latent = last_latent
+
+        sd_samplers_common.store_latent(self.last_latent)
+
+        self.step += 1
+        state.sampling_step = self.step
+        shared.total_tqdm.update()
+
+    def after_sample(self, x, ts, cond, uncond, res):
+        if not self.is_unipc:
+            self.update_step(res[1])
+
+        return x, ts, cond, uncond, res
+
+    def unipc_after_update(self, x, model_x):
+        self.update_step(x)
+
+    def initialize(self, p):
+        if self.is_ddim:
+            self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+        else:
+            self.eta = 0.0
+
+        if self.eta != 0.0:
+            p.extra_generation_params["Eta DDIM"] = self.eta
+
+        if self.is_unipc:
+            keys = [
+                ('UniPC variant', 'uni_pc_variant'),
+                ('UniPC skip type', 'uni_pc_skip_type'),
+                ('UniPC order', 'uni_pc_order'),
+                ('UniPC lower order final', 'uni_pc_lower_order_final'),
+            ]
+
+            for name, key in keys:
+                v = getattr(shared.opts, key)
+                if v != shared.opts.get_default(key):
+                    p.extra_generation_params[name] = v
+
+        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
+            if hasattr(self.sampler, fieldname):
+                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+        if self.is_unipc:
+            self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
+
+        self.mask = p.mask if hasattr(p, 'mask') else None
+        self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+
+    def adjust_steps_if_invalid(self, p, num_steps):
+        if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
+            if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
+                num_steps = shared.opts.uni_pc_order
+            valid_step = 999 / (1000 // num_steps)
+            if valid_step == math.floor(valid_step):
+                return int(valid_step) + 1
+
+        return num_steps
+
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+        steps = self.adjust_steps_if_invalid(p, steps)
+        self.initialize(p)
+
+        self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
+        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
+
+        self.init_latent = x
+        self.last_latent = x
+        self.step = 0
+
+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        if image_conditioning is not None:
+            if self.conditioning_key == "crossattn-adm":
+                conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
+                unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
+            else:
+                conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+        samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+
+        return samples
+
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        self.initialize(p)
+
+        self.init_latent = None
+        self.last_latent = x
+        self.step = 0
+
+        steps = self.adjust_steps_if_invalid(p, steps or p.steps)
+
+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
+        if image_conditioning is not None:
+            if self.conditioning_key == "crossattn-adm":
+                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
+                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
+            else:
+                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
+
+        samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+
+        return samples_ddim
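The uniform-schedule check in adjust_steps_if_invalid bumps the step count when 999 / (1000 // num_steps) comes out as an integer, which would push the DDIM timestep schedule out of range. The same arithmetic in isolation, with two concrete step counts:

import math

def adjust_uniform_steps(num_steps):
    # mirrors the check in adjust_steps_if_invalid: when 999 divides evenly
    # by (1000 // num_steps), the uniform timestep schedule would land out of
    # range, so the step count is bumped by one
    valid_step = 999 / (1000 // num_steps)
    if valid_step == math.floor(valid_step):
        return int(valid_step) + 1
    return num_steps

print(adjust_uniform_steps(20))  # 999 / 50 = 19.98 -> stays 20
print(adjust_uniform_steps(27))  # 999 / 37 = 27.0  -> becomes 28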
modules/sd_samplers_kdiffusion.py
ADDED
@@ -0,0 +1,476 @@
+from collections import deque
+import torch
+import inspect
+import k_diffusion.sampling
+from modules import prompt_parser, devices, sd_samplers_common
+
+from modules.shared import opts, state
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+
+samplers_k_diffusion = [
+    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
+    ('Euler', 'sample_euler', ['k_euler'], {}),
+    ('LMS', 'sample_lms', ['k_lms'], {}),
+    ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
+    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
+    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
+    ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
+    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
+    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
+    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
+    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+    ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+]
+
+samplers_data_k_diffusion = [
+    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
+    for label, funcname, aliases, options in samplers_k_diffusion
+    if hasattr(k_diffusion.sampling, funcname)
+]
+
+sampler_extra_params = {
+    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+}
+
+k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
+k_diffusion_scheduler = {
+    'Automatic': None,
+    'karras': k_diffusion.sampling.get_sigmas_karras,
+    'exponential': k_diffusion.sampling.get_sigmas_exponential,
+    'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
+
+def catenate_conds(conds):
+    if not isinstance(conds[0], dict):
+        return torch.cat(conds)
+
+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+    if not isinstance(cond, dict):
+        return cond[a:b]
+
+    return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+    if not isinstance(tensor, dict):
+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+    return tensor
+
+
+class CFGDenoiser(torch.nn.Module):
+    """
+    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+    negative prompt.
+    """
+
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+        self.mask = None
+        self.nmask = None
+        self.init_latent = None
+        self.step = 0
+        self.image_cfg_scale = None
+        self.padded_cond_uncond = False
+
+    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+        denoised_uncond = x_out[-uncond.shape[0]:]
+        denoised = torch.clone(denoised_uncond)
+
+        for i, conds in enumerate(conds_list):
+            for cond_index, weight in conds:
+                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+        return denoised
+
+    def combine_denoised_for_edit_model(self, x_out, cond_scale):
+        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+        return denoised
+
+    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+        if state.interrupted or state.skipped:
+            raise sd_samplers_common.InterruptedException
+
+        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+        # so is_edit_model is set to False to support AND composition.
+        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
+        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
+        batch_size = len(conds_list)
+        repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+        if shared.sd_model.model.conditioning_key == "crossattn-adm":
+            image_uncond = torch.zeros_like(image_cond)
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
+        else:
+            image_uncond = image_cond
+            if isinstance(uncond, dict):
+                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+            else:
+                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
+
+        if not is_edit_model:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
+        else:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
+
+        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+        cfg_denoiser_callback(denoiser_params)
+        x_in = denoiser_params.x
+        image_cond_in = denoiser_params.image_cond
+        sigma_in = denoiser_params.sigma
+        tensor = denoiser_params.text_cond
+        uncond = denoiser_params.text_uncond
+        skip_uncond = False
+
+        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+            skip_uncond = True
+            x_in = x_in[:-batch_size]
+            sigma_in = sigma_in[:-batch_size]
+
+        self.padded_cond_uncond = False
+        if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+            empty = shared.sd_model.cond_stage_model_empty_prompt
+            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+            if num_repeats < 0:
+                tensor = pad_cond(tensor, -num_repeats, empty)
+                self.padded_cond_uncond = True
+            elif num_repeats > 0:
+                uncond = pad_cond(uncond, num_repeats, empty)
+                self.padded_cond_uncond = True
+
+        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+            if is_edit_model:
+                cond_in = catenate_conds([tensor, uncond, uncond])
+            elif skip_uncond:
+                cond_in = tensor
+            else:
+                cond_in = catenate_conds([tensor, uncond])
+
+            if shared.batch_cond_uncond:
+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
+            else:
+                x_out = torch.zeros_like(x_in)
+                for batch_offset in range(0, x_out.shape[0], batch_size):
+                    a = batch_offset
+                    b = a + batch_size
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
+        else:
+            x_out = torch.zeros_like(x_in)
+            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+            for batch_offset in range(0, tensor.shape[0], batch_size):
+                a = batch_offset
+                b = min(a + batch_size, tensor.shape[0])
+
+                if not is_edit_model:
+                    c_crossattn = subscript_cond(tensor, a, b)
+                else:
+                    c_crossattn = torch.cat([tensor[a:b], uncond])  # fixed: both tensors belong inside torch.cat's list argument
+
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
+
+            if not skip_uncond:
+                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
+
+        denoised_image_indexes = [x[0][0] for x in conds_list]
+        if skip_uncond:
+            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
+
+        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+        cfg_denoised_callback(denoised_params)
+
+        devices.test_for_nans(x_out, "unet")
+
+        if opts.live_preview_content == "Prompt":
+            sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
+        elif opts.live_preview_content == "Negative prompt":
+            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
+
+        if is_edit_model:
+            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+        elif skip_uncond:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+        else:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+
+        if self.mask is not None:
+            denoised = self.init_latent * self.mask + self.nmask * denoised
+
+        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+        cfg_after_cfg_callback(after_cfg_callback_params)
+        denoised = after_cfg_callback_params.x
+
+        self.step += 1
+        return denoised
+
+
+class TorchHijack:
+    def __init__(self, sampler_noises):
+        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+        # implementation.
+        self.sampler_noises = deque(sampler_noises)
+
+    def __getattr__(self, item):
+        if item == 'randn_like':
+            return self.randn_like
+
+        if hasattr(torch, item):
+            return getattr(torch, item)
+
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
+
+    def randn_like(self, x):
+        if self.sampler_noises:
+            noise = self.sampler_noises.popleft()
+            if noise.shape == x.shape:
+                return noise
+
+        if opts.randn_source == "CPU" or x.device.type == 'mps':
+            return torch.randn_like(x, device=devices.cpu).to(x.device)
+        else:
+            return torch.randn_like(x)
+
+
+class KDiffusionSampler:
+    def __init__(self, funcname, sd_model):
+        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
+        self.funcname = funcname
+        self.func = getattr(k_diffusion.sampling, self.funcname)
+        self.extra_params = sampler_extra_params.get(funcname, [])
+        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+        self.sampler_noises = None
+        self.stop_at = None
+        self.eta = None
+        self.config = None  # set by the function calling the constructor
+        self.last_latent = None
+        self.s_min_uncond = None
+
+        self.conditioning_key = sd_model.model.conditioning_key
+
+    def callback_state(self, d):
+        step = d['i']
+        latent = d["denoised"]
+        if opts.live_preview_content == "Combined":
+            sd_samplers_common.store_latent(latent)
+        self.last_latent = latent
+
+        if self.stop_at is not None and step > self.stop_at:
+            raise sd_samplers_common.InterruptedException
+
+        state.sampling_step = step
+        shared.total_tqdm.update()
+
+    def launch_sampling(self, steps, func):
+        state.sampling_steps = steps
+        state.sampling_step = 0
+
+        try:
+            return func()
+        except RecursionError:
+            print(
+                'Encountered RecursionError during sampling, returning last latent. '
+                'rho >5 with a polyexponential scheduler may cause this error. '
+                'You should try to use a smaller rho value instead.'
+            )
+            return self.last_latent
+        except sd_samplers_common.InterruptedException:
+            return self.last_latent
+
+    def number_of_needed_noises(self, p):
+        return p.steps
+
+    def initialize(self, p):
+        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+        self.model_wrap_cfg.step = 0
+        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+        self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
+
+        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+
+        extra_params_kwargs = {}
+        for param_name in self.extra_params:
+            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+                extra_params_kwargs[param_name] = getattr(p, param_name)
+
+        if 'eta' in inspect.signature(self.func).parameters:
+            if self.eta != 1.0:
+                p.extra_generation_params["Eta"] = self.eta
+
+            extra_params_kwargs['eta'] = self.eta
+
+        return extra_params_kwargs
+
+    def get_sigmas(self, p, steps):
+        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+            discard_next_to_last_sigma = True
+            p.extra_generation_params["Discard penultimate sigma"] = True
+
+        steps += 1 if discard_next_to_last_sigma else 0
+
+        if p.sampler_noise_scheduler_override:
+            sigmas = p.sampler_noise_scheduler_override(steps)
+        elif opts.k_sched_type != "Automatic":
+            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
+            sigmas_kwargs = {
+                'sigma_min': sigma_min,
+                'sigma_max': sigma_max,
+            }
+
+            sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+            p.extra_generation_params["Schedule type"] = opts.k_sched_type
+
+            if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
+                sigmas_kwargs['sigma_min'] = opts.sigma_min
+                p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
+            if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
+                sigmas_kwargs['sigma_max'] = opts.sigma_max
+                p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
+
+            default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
+
+            if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
+                sigmas_kwargs['rho'] = opts.rho
+                p.extra_generation_params["Schedule rho"] = opts.rho
+
+            sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
+        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
+            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+        else:
+            sigmas = self.model_wrap.get_sigmas(steps)
+
+        if discard_next_to_last_sigma:
+            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+        return sigmas
+
+    def create_noise_sampler(self, x, sigmas, p):
+        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+        if shared.opts.no_dpmpp_sde_batch_determinism:
+            return None
+
+        from k_diffusion.sampling import BrownianTreeNoiseSampler
+        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+
+        sigmas = self.get_sigmas(p, steps)
+
+        sigma_sched = sigmas[steps - t_enc - 1:]
+        xi = x + noise * sigma_sched[0]
+
+        extra_params_kwargs = self.initialize(p)
+        parameters = inspect.signature(self.func).parameters
+
+        if 'sigma_min' in parameters:
+            # last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
+            extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+        if 'sigma_max' in parameters:
+            extra_params_kwargs['sigma_max'] = sigma_sched[0]
+        if 'n' in parameters:
+            extra_params_kwargs['n'] = len(sigma_sched) - 1
+        if 'sigma_sched' in parameters:
+            extra_params_kwargs['sigma_sched'] = sigma_sched
+        if 'sigmas' in parameters:
+            extra_params_kwargs['sigmas'] = sigma_sched
+
+        if self.config.options.get('brownian_noise', False):
+            noise_sampler = self.create_noise_sampler(x, sigmas, p)
+            extra_params_kwargs['noise_sampler'] = noise_sampler
+
+        self.model_wrap_cfg.init_latent = x
+        self.last_latent = x
+        extra_args = {
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale,
+            's_min_uncond': self.s_min_uncond
+        }
+
+        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+        if self.model_wrap_cfg.padded_cond_uncond:
+            p.extra_generation_params["Pad conds"] = True
+
+        return samples
+
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        steps = steps or p.steps
+
+        sigmas = self.get_sigmas(p, steps)
+
+        x = x * sigmas[0]
+
+        extra_params_kwargs = self.initialize(p)
+        parameters = inspect.signature(self.func).parameters
+
+        if 'sigma_min' in parameters:
+            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
+            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
+            if 'n' in parameters:
+                extra_params_kwargs['n'] = steps
+        else:
+            extra_params_kwargs['sigmas'] = sigmas
+
+        if self.config.options.get('brownian_noise', False):
+            noise_sampler = self.create_noise_sampler(x, sigmas, p)
+            extra_params_kwargs['noise_sampler'] = noise_sampler
+
+        self.last_latent = x
+        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale,
+            's_min_uncond': self.s_min_uncond
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+        if self.model_wrap_cfg.padded_cond_uncond:
+            p.extra_generation_params["Pad conds"] = True
+
+        return samples
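combine_denoised above is the usual classifier-free-guidance mix: start from the uncond prediction and add the scaled (cond - uncond) direction for each weighted prompt. A tensor-level sketch with a single one-prompt image, with the uncond batch size passed directly (requires only torch):

import torch

def combine_denoised(x_out, conds_list, uncond_count, cond_scale):
    # mirrors CFGDenoiser.combine_denoised: start from the uncond prediction
    # and add the scaled (cond - uncond) direction for every weighted prompt
    denoised_uncond = x_out[-uncond_count:]
    denoised = torch.clone(denoised_uncond)
    for i, conds in enumerate(conds_list):
        for cond_index, weight in conds:
            denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
    return denoised

# one image, one prompt of weight 1.0: x_out stacks [cond_pred, uncond_pred]
x_out = torch.stack([torch.full((4,), 2.0), torch.zeros(4)])
print(combine_denoised(x_out, [[(0, 1.0)]], 1, cond_scale=7.0))  # 0 + 7*(2-0) = 14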
modules/sd_unet.py
ADDED
@@ -0,0 +1,92 @@
+import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
+
+from modules import script_callbacks, shared, devices
+
+unet_options = []
+current_unet_option = None
+current_unet = None
+
+
+def list_unets():
+    new_unets = script_callbacks.list_unets_callback()
+
+    unet_options.clear()
+    unet_options.extend(new_unets)
+
+
+def get_unet_option(option=None):
+    option = option or shared.opts.sd_unet
+
+    if option == "None":
+        return None
+
+    if option == "Automatic":
+        name = shared.sd_model.sd_checkpoint_info.model_name
+
+        options = [x for x in unet_options if x.model_name == name]
+
+        option = options[0].label if options else "None"
+
+    return next(iter([x for x in unet_options if x.label == option]), None)
+
+
+def apply_unet(option=None):
+    global current_unet_option
+    global current_unet
+
+    new_option = get_unet_option(option)
+    if new_option == current_unet_option:
+        return
+
+    if current_unet is not None:
+        print(f"Deactivating unet: {current_unet.option.label}")
+        current_unet.deactivate()
+
+    current_unet_option = new_option
+    if current_unet_option is None:
+        current_unet = None
+
+        if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+            shared.sd_model.model.diffusion_model.to(devices.device)
+
+        return
+
+    shared.sd_model.model.diffusion_model.to(devices.cpu)
+    devices.torch_gc()
+
+    current_unet = current_unet_option.create_unet()
+    current_unet.option = current_unet_option
+    print(f"Activating unet: {current_unet.option.label}")
+    current_unet.activate()
+
+
+class SdUnetOption:
+    model_name = None
+    """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
+
+    label = None
+    """name of the unet in UI"""
+
+    def create_unet(self):
+        """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
+        raise NotImplementedError()
+
+
+class SdUnet(torch.nn.Module):
+    def forward(self, x, timesteps, context, *args, **kwargs):
+        raise NotImplementedError()
+
+    def activate(self):
+        pass
+
+    def deactivate(self):
+        pass
+
+
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+    if current_unet is not None:
+        return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
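For context, a replacement unet is supplied by subclassing the two interfaces above and registering through script_callbacks. A minimal sketch of such an extension script, assuming the usual on_list_unets registration hook and with hypothetical names throughout; it only runs inside the webui's import context:

import torch

from modules import script_callbacks, sd_unet  # webui import context assumed

class ExampleUnetOption(sd_unet.SdUnetOption):
    model_name = "sd15-example"  # hypothetical checkpoint name to auto-match
    label = "Example [custom]"   # hypothetical label shown in the UI dropdown

    def create_unet(self):
        return ExampleUnet()

class ExampleUnet(sd_unet.SdUnet):
    def forward(self, x, timesteps, context, *args, **kwargs):
        # a real extension would run an optimized engine here; returning x
        # unchanged only demonstrates the calling contract of UNetModel_forward
        return x

    def activate(self):
        pass  # load weights / build the engine

    def deactivate(self):
        pass  # release memory

def append_example_unet(unets):
    unets.append(ExampleUnetOption())

script_callbacks.on_list_unets(append_example_unet)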
modules/sd_vae.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import collections
|
3 |
+
from modules import paths, shared, devices, script_callbacks, sd_models
|
4 |
+
import glob
|
5 |
+
from copy import deepcopy
|
6 |
+
|
7 |
+
|
8 |
+
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
9 |
+
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
10 |
+
vae_dict = {}
|
11 |
+
|
12 |
+
|
13 |
+
base_vae = None
|
14 |
+
loaded_vae_file = None
|
15 |
+
checkpoint_info = None
|
16 |
+
|
17 |
+
checkpoints_loaded = collections.OrderedDict()
|
18 |
+
|
19 |
+
def get_base_vae(model):
|
20 |
+
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
21 |
+
return base_vae
|
22 |
+
return None
|
23 |
+
|
24 |
+
|
25 |
+
def store_base_vae(model):
|
26 |
+
global base_vae, checkpoint_info
|
27 |
+
if checkpoint_info != model.sd_checkpoint_info:
|
28 |
+
assert not loaded_vae_file, "Trying to store non-base VAE!"
|
29 |
+
base_vae = deepcopy(model.first_stage_model.state_dict())
|
30 |
+
checkpoint_info = model.sd_checkpoint_info
|
31 |
+
|
32 |
+
|
33 |
+
def delete_base_vae():
|
34 |
+
global base_vae, checkpoint_info
|
35 |
+
base_vae = None
|
36 |
+
checkpoint_info = None
|
37 |
+
|
38 |
+
|
39 |
+
def restore_base_vae(model):
|
40 |
+
global loaded_vae_file
|
41 |
+
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
42 |
+
print("Restoring base VAE")
|
43 |
+
_load_vae_dict(model, base_vae)
|
44 |
+
loaded_vae_file = None
|
45 |
+
delete_base_vae()
|
46 |
+
|
47 |
+
|
48 |
+
def get_filename(filepath):
    return os.path.basename(filepath)


def refresh_vae_list():
    vae_dict.clear()

    paths = [
        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
        os.path.join(sd_models.model_path, '**/*.vae.pt'),
        os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
        os.path.join(vae_path, '**/*.ckpt'),
        os.path.join(vae_path, '**/*.pt'),
        os.path.join(vae_path, '**/*.safetensors'),
    ]

    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
        paths += [
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
        ]

    if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
        paths += [
            os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
            os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
        ]

    candidates = []
    for path in paths:
        candidates += glob.iglob(path, recursive=True)

    for filepath in candidates:
        name = get_filename(filepath)
        vae_dict[name] = filepath


def find_vae_near_checkpoint(checkpoint_file):
    checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
    for vae_file in vae_dict.values():
        if os.path.basename(vae_file).startswith(checkpoint_path):
            return vae_file

    return None


def resolve_vae(checkpoint_file):
    if shared.cmd_opts.vae_path is not None:
        return shared.cmd_opts.vae_path, 'from commandline argument'

    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config

    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
        return vae_near_checkpoint, 'found near the checkpoint'

    if shared.opts.sd_vae == "None":
        return None, None

    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
    if vae_from_options is not None:
        return vae_from_options, 'specified in settings'

    if not is_automatic:
        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")

    return None, None


def load_vae_dict(filename, map_location):
    vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
    vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
    return vae_dict_1


def load_vae(model, vae_file=None, vae_source="from unknown source"):
    global vae_dict, loaded_vae_file
    # save_settings = False

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

    if vae_file:
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
        else:
            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
            print(f"Loading VAE weights {vae_source}: {vae_file}")
            store_base_vae(model)

            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
                # cache newly loaded vae
                checkpoints_loaded[vae_file] = vae_dict_1.copy()

        # clean up cache if limit is reached
        if cache_enabled:
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:  # we need to count the current model
                checkpoints_loaded.popitem(last=False)  # LRU

        # If vae used is not in dict, update it
        # It will be removed on refresh though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file

    elif loaded_vae_file:
        restore_base_vae(model)

    loaded_vae_file = vae_file


# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    model.first_stage_model.load_state_dict(vae_dict_1)
    model.first_stage_model.to(devices.dtype_vae)


def clear_loaded_vae():
    global loaded_vae_file
    loaded_vae_file = None


unspecified = object()


def reload_vae_weights(sd_model=None, vae_file=unspecified):
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename

    if vae_file == unspecified:
        vae_file, vae_source = resolve_vae(checkpoint_file)
    else:
        vae_source = "from function argument"

    if loaded_vae_file == vae_file:
        return

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file, vae_source)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print("VAE weights loaded.")
    return sd_model
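A quick usage sketch (editor's addition, not part of the commit) of how the functions above fit together; the checkpoint path is hypothetical:

# Editor's sketch, not part of the commit; the checkpoint path is hypothetical.
from modules import sd_vae

sd_vae.refresh_vae_list()  # scan model and VAE directories into vae_dict

# Resolution priority: --vae-path, then a "<checkpoint>.vae.*" file found next to
# the checkpoint (if sd_vae_as_default or "Automatic"), then the sd_vae setting, then None.
vae_file, vae_source = sd_vae.resolve_vae("models/Stable-diffusion/model.safetensors")

# Swaps the VAE on the currently loaded model only if the resolved file changed.
sd_vae.reload_vae_weights()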
modules/sd_vae_approx.py
ADDED
@@ -0,0 +1,86 @@
import os

import torch
from torch import nn
from modules import devices, paths, shared

sd_vae_approx_models = {}


class VAEApprox(nn.Module):
    def __init__(self):
        super(VAEApprox, self).__init__()
        self.conv1 = nn.Conv2d(4, 8, (7, 7))
        self.conv2 = nn.Conv2d(8, 16, (5, 5))
        self.conv3 = nn.Conv2d(16, 32, (3, 3))
        self.conv4 = nn.Conv2d(32, 64, (3, 3))
        self.conv5 = nn.Conv2d(64, 32, (3, 3))
        self.conv6 = nn.Conv2d(32, 16, (3, 3))
        self.conv7 = nn.Conv2d(16, 8, (3, 3))
        self.conv8 = nn.Conv2d(8, 3, (3, 3))

    def forward(self, x):
        extra = 11
        x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
        x = nn.functional.pad(x, (extra, extra, extra, extra))

        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
            x = layer(x)
            x = nn.functional.leaky_relu(x, 0.1)

        return x


def download_model(model_path, model_url):
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        print(f'Downloading VAEApprox model to: {model_path}')
        torch.hub.download_url_to_file(model_url, model_path)


def model():
    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
    loaded_model = sd_vae_approx_models.get(model_name)

    if loaded_model is None:
        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
        if not os.path.exists(model_path):
            model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)

        if not os.path.exists(model_path):
            model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
            download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)

        loaded_model = VAEApprox()
        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
        loaded_model.eval()
        loaded_model.to(devices.device, devices.dtype)
        sd_vae_approx_models[model_name] = loaded_model

    return loaded_model


def cheap_approximation(sample):
    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

    if shared.sd_model.is_sdxl:
        coeffs = [
            [ 0.3448,  0.4168,  0.4395],
            [-0.1953, -0.0290,  0.0250],
            [ 0.1074,  0.0886, -0.0163],
            [-0.3730, -0.2499, -0.2088],
        ]
    else:
        coeffs = [
            [ 0.298,  0.207,  0.208],
            [ 0.187,  0.286,  0.173],
            [-0.158,  0.189,  0.264],
            [-0.184, -0.271, -0.473],
        ]

    coefs = torch.tensor(coeffs).to(sample.device)

    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)

    return x_sample
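The coefficient tables in cheap_approximation define a per-pixel linear map from the four latent channels to RGB; the einsum computes rgb[r, x, y] = sum over l of sample[l, x, y] * coefs[l, r]. A minimal self-contained sketch (editor's addition, shapes assumed for an SD 1.x latent):

import torch

latent = torch.randn(4, 64, 64)  # one SD 1.x latent: 4 channels, H/8 x W/8
coefs = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])
rgb = torch.einsum("lxy,lr -> rxy", latent, coefs)  # shape (3, 64, 64), values roughly in [-1, 1]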
modules/sd_vae_taesd.py
ADDED
@@ -0,0 +1,88 @@
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)

https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn

from modules import devices, paths_internal, shared

sd_vae_taesd_models = {}


def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    @staticmethod
    def forward(x):
        return torch.tanh(x / 3) * 3


class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))


def decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )


class TAESD(nn.Module):
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.decoder = decoder()
        self.decoder.load_state_dict(
            torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)


def download_model(model_path, model_url):
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        print(f'Downloading TAESD decoder to: {model_path}')
        torch.hub.download_url_to_file(model_url, model_path)


def model():
    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
    loaded_model = sd_vae_taesd_models.get(model_name)

    if loaded_model is None:
        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)

        if os.path.exists(model_path):
            loaded_model = TAESD(model_path)
            loaded_model.eval()
            loaded_model.to(devices.device, devices.dtype)
            sd_vae_taesd_models[model_name] = loaded_model
        else:
            raise FileNotFoundError('TAESD model not found')

    return loaded_model.decoder
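A sketch (editor's addition, not part of the commit) of decoding one latent with the decoder returned by model(). The input latent here is random and its pre-scaling convention is left to the caller; the decoder's output lands approximately in [0, 1]:

import torch
from modules import devices, sd_vae_taesd

latent = torch.randn(4, 64, 64)  # hypothetical SD 1.x latent
decoder = sd_vae_taesd.model()   # downloads taesd_decoder.pth on first use
with torch.no_grad():
    rgb = decoder(latent.to(devices.device, devices.dtype).unsqueeze(0))[0]
rgb = rgb.clamp(0, 1)            # decoder output is approximately in [0, 1]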
modules/shared.py
ADDED
@@ -0,0 +1,912 @@
import datetime
import json
import os
import re
import sys
import threading
import time
import logging

import gradio as gr
import torch
import tqdm

import launch
import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args

from modules.generation_parameters_copypaste import infotext_to_setting_name_mapping
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional

log = logging.getLogger(__name__)

demo = None

parser = cmd_args.parser

script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)

if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
    cmd_opts = parser.parse_args()
else:
    cmd_opts, _ = parser.parse_known_args()


restricted_opts = {
    "samples_filename_pattern",
    "directories_filename_pattern",
    "outdir_samples",
    "outdir_txt2img_samples",
    "outdir_img2img_samples",
    "outdir_extras_samples",
    "outdir_grids",
    "outdir_txt2img_grids",
    "outdir_save",
    "outdir_init_images"
}

# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
    "gradio/glass",
    "gradio/monochrome",
    "gradio/seafoam",
    "gradio/soft",
    "freddyaboulton/dracula_revamped",
    "gradio/dracula_test",
    "abidlabs/dracula_test",
    "abidlabs/pakistan",
    "dawood/microsoft_windows",
    "ysharma/steampunk"
]


cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access

devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
    (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
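# Editor's note, not part of the commit: the tuple assignment above walks the fixed
# list ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'] and routes a component
# to CPU when its name (or 'all') appears in --use-cpu. A standalone equivalent:
#
#     def pick_devices(use_cpu):
#         return tuple(
#             "cpu" if (name in use_cpu or "all" in use_cpu) else "optimal"
#             for name in ["sd", "interrogate", "gfpgan", "esrgan", "codeformer"]
#         )
#
#     pick_devices(["interrogate"])  # -> ("optimal", "cpu", "optimal", "optimal", "optimal")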
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16

device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"

batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file

os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
loaded_hypernetworks = []


def reload_hypernetworks():
    from modules.hypernetworks import hypernetwork
    global hypernetworks

    hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)


class State:
    skipped = False
    interrupted = False
    job = ""
    job_no = 0
    job_count = 0
    processing_has_refined_job_count = False
    job_timestamp = '0'
    sampling_step = 0
    sampling_steps = 0
    current_latent = None
    current_image = None
    current_image_sampling_step = 0
    id_live_preview = 0
    textinfo = None
    time_start = None
    server_start = None
    _server_command_signal = threading.Event()
    _server_command: Optional[str] = None

    @property
    def need_restart(self) -> bool:
        # Compatibility getter for need_restart.
        return self.server_command == "restart"

    @need_restart.setter
    def need_restart(self, value: bool) -> None:
        # Compatibility setter for need_restart.
        if value:
            self.server_command = "restart"

    @property
    def server_command(self):
        return self._server_command

    @server_command.setter
    def server_command(self, value: Optional[str]) -> None:
        """
        Set the server command to `value` and signal that it's been set.
        """
        self._server_command = value
        self._server_command_signal.set()

    def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
        """
        Wait for server command to get set; return and clear the value and signal.
        """
        if self._server_command_signal.wait(timeout):
            self._server_command_signal.clear()
            req = self._server_command
            self._server_command = None
            return req
        return None

    def request_restart(self) -> None:
        self.interrupt()
        self.server_command = "restart"
        log.info("Received restart request")

    def skip(self):
        self.skipped = True
        log.info("Received skip request")

    def interrupt(self):
        self.interrupted = True
        log.info("Received interrupt request")

    def nextjob(self):
        if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
            self.do_set_current_image()

        self.job_no += 1
        self.sampling_step = 0
        self.current_image_sampling_step = 0

    def dict(self):
        obj = {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }

        return obj

    def begin(self, job: str = "(unknown)"):
        self.sampling_step = 0
        self.job_count = -1
        self.processing_has_refined_job_count = False
        self.job_no = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.current_latent = None
        self.current_image = None
        self.current_image_sampling_step = 0
        self.id_live_preview = 0
        self.skipped = False
        self.interrupted = False
        self.textinfo = None
        self.time_start = time.time()
        self.job = job
        devices.torch_gc()
        log.info("Starting job %s", job)

    def end(self):
        duration = time.time() - self.time_start
        log.info("Ending job %s (%.2f seconds)", self.job, duration)
        self.job = ""
        self.job_count = 0

        devices.torch_gc()

    def set_current_image(self):
        """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
        if not parallel_processing_allowed:
            return

        if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
            self.do_set_current_image()

    def do_set_current_image(self):
        if self.current_latent is None:
            return

        import modules.sd_samplers
        if opts.show_progress_grid:
            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
        else:
            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))

        self.current_image_sampling_step = self.sampling_step

    def assign_current_image(self, image):
        self.current_image = image
        self.id_live_preview += 1


state = State()
state.server_start = time.time()
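# Editor's sketch, not part of the commit: State doubles as a one-slot, thread-safe
# mailbox for server commands. A server loop blocks on wait_for_server_command()
# while another thread (e.g. a UI callback) sets the command:
#
#     # in the UI thread:
#     state.request_restart()          # interrupts the job, sets "restart", signals
#
#     # in the server loop:
#     command = state.wait_for_server_command(timeout=5)
#     if command == "restart":
#         ...                          # tear the app down and relaunch it; the
#                                      # mailbox is cleared before the call returns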
styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)

interrogator = modules.interrogate.InterrogateModels("interrogate")

face_restorers = []


class OptionInfo:
    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
        self.default = default
        self.label = label
        self.component = component
        self.component_args = component_args
        self.onchange = onchange
        self.section = section
        self.refresh = refresh

        self.comment_before = comment_before
        """HTML text that will be added before label in UI"""

        self.comment_after = comment_after
        """HTML text that will be added after label in UI"""

    def link(self, label, url):
        self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
        return self

    def js(self, label, js_func):
        self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
        return self

    def info(self, info):
        self.comment_after += f"<span class='info'>({info})</span>"
        return self

    def html(self, html):
        self.comment_after += html
        return self

    def needs_restart(self):
        self.comment_after += " <span class='info'>(requires restart)</span>"
        return self


def options_section(section_identifier, options_dict):
    for v in options_dict.values():
        v.section = section_identifier

    return options_dict
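# Editor's sketch, not part of the commit: how a new settings section would be
# registered with the machinery above -- OptionInfo describes one setting (default,
# label, gradio component, component args) and options_section() stamps each entry
# with its section tuple before it is merged into options_templates:
#
#     options_templates.update(options_section(('my-section', "My section"), {
#         "my_flag": OptionInfo(False, "Enable my feature"),
#         "my_scale": OptionInfo(1.0, "My scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.1}).info("1 = default"),
#     }))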
|
292 |
+
|
293 |
+
|
294 |
+
def list_checkpoint_tiles():
|
295 |
+
import modules.sd_models
|
296 |
+
return modules.sd_models.checkpoint_tiles()
|
297 |
+
|
298 |
+
|
299 |
+
def refresh_checkpoints():
|
300 |
+
import modules.sd_models
|
301 |
+
return modules.sd_models.list_models()
|
302 |
+
|
303 |
+
|
304 |
+
def list_samplers():
|
305 |
+
import modules.sd_samplers
|
306 |
+
return modules.sd_samplers.all_samplers
|
307 |
+
|
308 |
+
|
309 |
+
|
310 |
+
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
311 |
+
tab_names = []
|
312 |
+
|
313 |
+
options_templates = {}
|
314 |
+
|
315 |
+
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
316 |
+
"samples_save": OptionInfo(True, "Always save all generated images"),
|
317 |
+
"samples_format": OptionInfo('png', 'File format for images'),
|
318 |
+
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
319 |
+
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
320 |
+
|
321 |
+
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
322 |
+
"grid_format": OptionInfo('png', 'File format for grids'),
|
323 |
+
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
324 |
+
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
325 |
+
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
326 |
+
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
327 |
+
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
328 |
+
"font": OptionInfo("", "Font for image grids that have text"),
|
329 |
+
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
330 |
+
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
331 |
+
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
332 |
+
|
333 |
+
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
334 |
+
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
335 |
+
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
336 |
+
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
337 |
+
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
338 |
+
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
339 |
+
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
340 |
+
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
341 |
+
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
342 |
+
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
343 |
+
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
344 |
+
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
345 |
+
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
346 |
+
|
347 |
+
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
348 |
+
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
349 |
+
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
350 |
+
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
351 |
+
|
352 |
+
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
353 |
+
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
354 |
+
|
355 |
+
}))
|
356 |
+
|
357 |
+
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
358 |
+
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
359 |
+
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
360 |
+
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
361 |
+
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
362 |
+
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
363 |
+
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
364 |
+
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
365 |
+
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
366 |
+
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
367 |
+
}))
|
368 |
+
|
369 |
+
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
370 |
+
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
371 |
+
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
372 |
+
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
373 |
+
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
374 |
+
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
375 |
+
}))
|
376 |
+
|
377 |
+
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
378 |
+
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
379 |
+
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
380 |
+
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
381 |
+
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
382 |
+
}))
|
383 |
+
|
384 |
+
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
385 |
+
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
386 |
+
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
387 |
+
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
388 |
+
}))
|
389 |
+
|
390 |
+
options_templates.update(options_section(('system', "System"), {
|
391 |
+
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
392 |
+
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
393 |
+
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
394 |
+
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
395 |
+
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
396 |
+
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
397 |
+
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
398 |
+
}))
|
399 |
+
|
400 |
+
options_templates.update(options_section(('training', "Training"), {
|
401 |
+
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
402 |
+
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
403 |
+
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
404 |
+
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
|
405 |
+
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
406 |
+
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
407 |
+
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
408 |
+
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
409 |
+
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
410 |
+
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
411 |
+
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
412 |
+
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
413 |
+
}))
|
414 |
+
|
415 |
+
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
416 |
+
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
417 |
+
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
418 |
+
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
419 |
+
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
420 |
+
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
421 |
+
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
422 |
+
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
423 |
+
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
424 |
+
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
425 |
+
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
426 |
+
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
427 |
+
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
428 |
+
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
429 |
+
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
430 |
+
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
431 |
+
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
432 |
+
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
433 |
+
"sd_max_resolution": OptionInfo(2048, "Max resolution output for txt2img and img2img"),
|
434 |
+
"ignore_overrides": OptionInfo([], "Ignore Overrides", gr.CheckboxGroup, lambda: {"choices": [x[0] for x in infotext_to_setting_name_mapping]}),
|
435 |
+
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
436 |
+
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
437 |
+
}))
|
438 |
+
|
439 |
+
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
440 |
+
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
441 |
+
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
442 |
+
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
443 |
+
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
444 |
+
}))
|
445 |
+
|
446 |
+
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
447 |
+
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
448 |
+
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
449 |
+
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
450 |
+
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
451 |
+
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
452 |
+
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
453 |
+
"experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
|
454 |
+
}))
|
455 |
+
|
456 |
+
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
457 |
+
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
458 |
+
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
459 |
+
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
460 |
+
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
461 |
+
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
462 |
+
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
463 |
+
}))
|
464 |
+
|
465 |
+
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
466 |
+
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
467 |
+
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
468 |
+
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
469 |
+
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
470 |
+
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
471 |
+
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
472 |
+
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
473 |
+
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
474 |
+
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
475 |
+
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
476 |
+
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
477 |
+
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
478 |
+
}))
|
479 |
+
|
480 |
+
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
481 |
+
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
482 |
+
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
483 |
+
#"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
484 |
+
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
485 |
+
#"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
486 |
+
#"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
487 |
+
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
488 |
+
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
489 |
+
|
490 |
+
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
491 |
+
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
492 |
+
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
493 |
+
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
494 |
+
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
495 |
+
"extra_networks_default_visibility": OptionInfo(True, "Extra Networks default visibility"),
|
496 |
+
"extra_networks_cards_size": OptionInfo(1, "Card size for extra networks", gr.Slider, {"minimum": 0.8, "maximum": 2, "step": 0.1}),
|
497 |
+
"extra_networks_cards_visible_rows": OptionInfo(1, "Visible card rows for extra networks", gr.Slider, {"minimum": 1, "maximum": 3, "step": 1}),
|
498 |
+
"extra_networks_aside": OptionInfo(True, "Extra Networks aside view"),
|
499 |
+
}))
|
500 |
+
|
501 |
+
options_templates.update(options_section(('ui', "User interface"), {
|
502 |
+
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
503 |
+
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
504 |
+
"img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
|
505 |
+
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
506 |
+
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
507 |
+
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
508 |
+
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
509 |
+
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
510 |
+
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
511 |
+
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
512 |
+
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
513 |
+
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
514 |
+
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
515 |
+
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
516 |
+
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
|
517 |
+
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
|
518 |
+
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
519 |
+
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
520 |
+
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
521 |
+
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
522 |
+
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
523 |
+
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
524 |
+
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
525 |
+
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
|
526 |
+
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
|
527 |
+
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
528 |
+
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings"),
|
529 |
+
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
530 |
+
"ui_hidden_tabs": OptionInfo("", "Hidden Tabs"),
|
531 |
+
"ui_header_tabs": OptionInfo("", "Header Tabs"),
|
532 |
+
"ui_views_order": OptionInfo("row-reverse", "Interface order input/parameters | output/preview", gr.Radio, {"choices": ["row", "row-reverse"]}),
|
533 |
+
"ui_output_image_fit": OptionInfo("Scale-down", "Generated image fit method", gr.Radio, {"choices": ["Scale-down", "Contain"]}),
|
534 |
+
"ui_show_range_ticks": OptionInfo(True, "Show ticks for range sliders"),
|
535 |
+
"ui_dispatch_input_release": OptionInfo(True, "Dispatch event change on release, for slider and input number components"),
|
536 |
+
"ui_no_slider_layout": OptionInfo(False, "No sliders compact layout mode"),
|
537 |
+
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
|
538 |
+
}))
|
539 |
+
|
540 |
+
options_templates.update(options_section(('infotext', "Infotext"), {
|
541 |
+
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
542 |
+
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
543 |
+
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
544 |
+
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
545 |
+
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
546 |
+
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
547 |
+
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
548 |
+
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
549 |
+
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
|
550 |
+
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
|
551 |
+
</ul>"""),
|
552 |
+
|
553 |
+
}))
|
554 |
+
|
555 |
+
options_templates.update(options_section(('ui', "Live previews"), {
|
556 |
+
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
557 |
+
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
558 |
+
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
559 |
+
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
560 |
+
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
561 |
+
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
562 |
+
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
563 |
+
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
564 |
+
"live_preview_image_fit": OptionInfo("Scale-down", "Live preview image fit method", gr.Radio, {"choices": ["Scale-down", "Contain"]}),
|
565 |
+
}))
|
566 |
+
|
567 |
+
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
568 |
+
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
|
569 |
+
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
570 |
+
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
571 |
+
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
572 |
+
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
573 |
+
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
574 |
+
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
575 |
+
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
576 |
+
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
577 |
+
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
|
578 |
+
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
|
579 |
+
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
580 |
+
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
581 |
+
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
582 |
+
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
583 |
+
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
584 |
+
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
585 |
+
}))
|
586 |
+
|
587 |
+
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
588 |
+
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
589 |
+
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
590 |
+
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
591 |
+
}))
|
592 |
+
|
593 |
+
options_templates.update(options_section((None, "Hidden options"), {
|
594 |
+
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
595 |
+
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
596 |
+
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
597 |
+
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
598 |
+
}))
|
599 |
+
|
600 |
+
|
601 |
+
options_templates.update()
|
602 |
+
|
603 |
+
|
604 |
+
class Options:
    data = None
    data_labels = options_templates
    typemap = {int: float}

    def __init__(self):
        self.data = {k: v.default for k, v in self.data_labels.items()}

    def __setattr__(self, key, value):
        # writes to known settings go into self.data, honoring freeze/restriction flags
        if self.data is not None:
            if key in self.data or key in self.data_labels:
                assert not cmd_opts.freeze_settings, "changing settings is disabled"

                info = opts.data_labels.get(key, None)
                comp_args = info.component_args if info else None
                if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
                    raise RuntimeError(f"not possible to set {key} because it is restricted")

                if cmd_opts.hide_ui_dir_config and key in restricted_opts:
                    raise RuntimeError(f"not possible to set {key} because it is restricted")

                self.data[key] = value
                return

        return super(Options, self).__setattr__(key, value)

    def __getattr__(self, item):
        # reads fall back to the option's default when the key was never set
        if self.data is not None:
            if item in self.data:
                return self.data[item]

        if item in self.data_labels:
            return self.data_labels[item].default

        return super(Options, self).__getattribute__(item)

    def set(self, key, value):
        """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""

        oldval = self.data.get(key, None)
        if oldval == value:
            return False

        try:
            setattr(self, key, value)
        except RuntimeError:
            return False

        if self.data_labels[key].onchange is not None:
            try:
                self.data_labels[key].onchange()
            except Exception as e:
                errors.display(e, f"changing setting {key} to {value}")
                setattr(self, key, oldval)
                return False

        return True

    def get_default(self, key):
        """returns the default value for the key"""

        data_label = self.data_labels.get(key)
        if data_label is None:
            return None

        return data_label.default

    def save(self, filename):
        assert not cmd_opts.freeze_settings, "saving settings is disabled"

        with open(filename, "w", encoding="utf8") as file:
            json.dump(self.data, file, indent=4)

    def same_type(self, x, y):
        if x is None or y is None:
            return True

        type_x = self.typemap.get(type(x), type(x))
        type_y = self.typemap.get(type(y), type(y))

        return type_x == type_y

    def load(self, filename):
        with open(filename, "r", encoding="utf8") as file:
            self.data = json.load(file)

        # 1.1.1 quicksettings list migration
        if self.data.get('quicksettings') is not None:
            self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]

        # 1.4.0 ui_reorder
        if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
            self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]

        bad_settings = 0
        for k, v in self.data.items():
            info = self.data_labels.get(k, None)
            if info is not None and not self.same_type(info.default, v):
                print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
                bad_settings += 1

        if bad_settings > 0:
            print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)

    def onchange(self, key, func, call=True):
        item = self.data_labels.get(key)
        item.onchange = func

        if call:
            func()

    def dumpjson(self):
        d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
        d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
        d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
        return json.dumps(d)

    def add_option(self, key, info):
        self.data_labels[key] = info

    def reorder(self):
        """reorder settings so that all items related to a section always go together"""

        section_ids = {}
        settings_items = self.data_labels.items()
        for _, item in settings_items:
            if item.section not in section_ids:
                section_ids[item.section] = len(section_ids)

        self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))

    def cast_value(self, key, value):
        """casts an arbitrary value to the same type as this setting's value with key
        Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
        """

        if value is None:
            return None

        default_value = self.data_labels[key].default
        if default_value is None:
            default_value = getattr(self, key, None)
        if default_value is None:
            return None

        expected_type = type(default_value)
        if expected_type == bool and value == "False":
            value = False
        else:
            value = expected_type(value)

        return value


opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)

class Shared(sys.modules[__name__].__class__):
    """
    this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
    at program startup.
    """

    sd_model_val = None

    @property
    def sd_model(self):
        import modules.sd_models

        return modules.sd_models.model_data.get_sd_model()

    @sd_model.setter
    def sd_model(self, value):
        import modules.sd_models

        modules.sd_models.model_data.set_sd_model(value)


sd_model: LatentDiffusion = None  # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
sys.modules[__name__].__class__ = Shared

settings_components = None
"""assigned from ui.py, a mapping of setting names to gradio components responsible for those settings"""

latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
    "Latent": {"mode": "bilinear", "antialias": False},
    "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
    "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
    "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
    "Latent (nearest)": {"mode": "nearest", "antialias": False},
    "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
}

sd_upscalers = []

clip_model = None

progress_print_out = sys.stdout

gradio_theme = gr.themes.Base()


def reload_gradio_theme(theme_name=None):
    global gradio_theme
    if not theme_name:
        theme_name = opts.gradio_theme

    default_theme_args = dict(
        font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
        font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
    )

    if theme_name == "Default":
        gradio_theme = gr.themes.Default(**default_theme_args)
    else:
        try:
            gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
        except Exception as e:
            errors.display(e, "changing gradio theme")
            gradio_theme = gr.themes.Default(**default_theme_args)


class TotalTQDM:
    def __init__(self):
        self._tqdm = None

    def reset(self):
        self._tqdm = tqdm.tqdm(
            desc="Total progress",
            total=state.job_count * state.sampling_steps,
            position=1,
            file=progress_print_out
        )

    def update(self):
        if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
            return
        if self._tqdm is None:
            self.reset()
        self._tqdm.update()

    def updateTotal(self, new_total):
        if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
            return
        if self._tqdm is None:
            self.reset()
        self._tqdm.total = new_total

    def clear(self):
        if self._tqdm is not None:
            self._tqdm.refresh()
            self._tqdm.close()
            self._tqdm = None


total_tqdm = TotalTQDM()

mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()


def natural_sort_key(s, regex=re.compile('([0-9]+)')):
    return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]


def listfiles(dirname):
    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
    return [file for file in filenames if os.path.isfile(file)]


def html_path(filename):
    return os.path.join(script_path, "html", filename)


def html(filename):
    path = html_path(filename)

    if os.path.exists(path):
        with open(path, encoding="utf8") as file:
            return file.read()

    return ""


def walk_files(path, allowed_extensions=None):
    if not os.path.exists(path):
        return

    if allowed_extensions is not None:
        allowed_extensions = set(allowed_extensions)

    items = list(os.walk(path, followlinks=True))
    items = sorted(items, key=lambda x: natural_sort_key(x[0]))

    for root, _, files in items:
        for filename in sorted(files, key=natural_sort_key):
            if allowed_extensions is not None:
                _, ext = os.path.splitext(filename)
                if ext not in allowed_extensions:
                    continue

            if not opts.list_hidden_files and ("/." in root or "\\." in root):
                continue

            yield os.path.join(root, filename)
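
The Options class above is the settings hub: reads fall through to per-option defaults via __getattr__, writes are validated in __setattr__, and set() runs the option's onchange callback with rollback on failure. A minimal sketch of the resulting behaviour (illustrative only; it assumes a loaded opts instance and that settings are not frozen via cmd_opts.freeze_settings):

    from modules.shared import opts

    changed = opts.set("eta_ancestral", 0.5)  # stores the value, fires onchange, returns True
    changed = opts.set("eta_ancestral", 0.5)  # same value again -> returns False, no callback
    v = opts.cast_value("eta_noise_seed_delta", "12")  # "12" (str) -> 12, cast to the default's type
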
modules/shared_items.py
ADDED
@@ -0,0 +1,69 @@
def realesrgan_models_names():
    import modules.realesrgan_model
    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]


def postprocessing_scripts():
    import modules.scripts

    return modules.scripts.scripts_postproc.scripts


def sd_vae_items():
    import modules.sd_vae

    return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)


def refresh_vae_list():
    import modules.sd_vae

    modules.sd_vae.refresh_vae_list()


def cross_attention_optimizations():
    import modules.sd_hijack

    return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]


def sd_unet_items():
    import modules.sd_unet

    return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]


def refresh_unet_list():
    import modules.sd_unet

    modules.sd_unet.list_unets()


ui_reorder_categories_builtin_items = [
    "inpaint",
    "sampler",
    "checkboxes",
    "hires_fix",
    "dimensions",
    "cfg",
    "seed",
    "batch",
    "override_settings",
]


def ui_reorder_categories():
    from modules import scripts

    yield from ui_reorder_categories_builtin_items

    sections = {}
    for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
        if isinstance(script.section, str):
            sections[script.section] = 1

    yield from sections

    yield "scripts"
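
Each helper here defers its heavyweight import to call time, so merely enumerating settings never loads the VAE, UNet, or script machinery. A hedged sketch of how an option definition can consume one of these choice providers (the sd_vae entry below is hypothetical, shown only to mirror the lambda pattern used by the postprocessing options in shared.py above):

    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}),
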
modules/styles.py
ADDED
@@ -0,0 +1,139 @@
import csv
import os
import os.path
import re
import typing
import shutil


class PromptStyle(typing.NamedTuple):
    name: str
    prompt: str
    negative_prompt: str


def merge_prompts(style_prompt: str, prompt: str) -> str:
    if "{prompt}" in style_prompt:
        res = style_prompt.replace("{prompt}", prompt)
    else:
        parts = filter(None, (prompt.strip(), style_prompt.strip()))
        res = ", ".join(parts)

    return res


def apply_styles_to_prompt(prompt, styles):
    for style in styles:
        prompt = merge_prompts(style, prompt)

    return prompt


re_spaces = re.compile(" +")


def extract_style_text_from_prompt(style_text, prompt):
    stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
    stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
    if "{prompt}" in stripped_style_text:
        # split around the first occurrence only; a stray second "{prompt}" would
        # otherwise produce three parts and break the two-way unpacking
        left, right = stripped_style_text.split("{prompt}", 1)
        if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
            prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
            return True, prompt
    else:
        if stripped_prompt.endswith(stripped_style_text):
            prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]

            if prompt.endswith(', '):
                prompt = prompt[:-2]

            return True, prompt

    return False, prompt


def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
    if not style.prompt and not style.negative_prompt:
        return False, prompt, negative_prompt

    match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
    if not match_positive:
        return False, prompt, negative_prompt

    match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
    if not match_negative:
        return False, prompt, negative_prompt

    return True, extracted_positive, extracted_negative


class StyleDatabase:
    def __init__(self, path: str):
        self.no_style = PromptStyle("None", "", "")
        self.styles = {}
        self.path = path

        self.reload()

    def reload(self):
        self.styles.clear()

        if not os.path.exists(self.path):
            return

        with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
            reader = csv.DictReader(file, skipinitialspace=True)
            for row in reader:
                # Support loading old CSV format with "name, text"-columns
                prompt = row["prompt"] if "prompt" in row else row["text"]
                negative_prompt = row.get("negative_prompt", "")
                self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)

    def get_style_prompts(self, styles):
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])

    def apply_negative_styles_to_prompt(self, prompt, styles):
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])

    def save_styles(self, path: str) -> None:
        # Always keep a backup file around
        if os.path.exists(path):
            shutil.copy(path, f"{path}.bak")

        fd = os.open(path, os.O_RDWR | os.O_CREAT)
        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
            writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
            writer.writeheader()
            writer.writerows(style._asdict() for k, style in self.styles.items())

    def extract_styles_from_prompt(self, prompt, negative_prompt):
        extracted = []

        applicable_styles = list(self.styles.values())

        while True:
            found_style = None

            for style in applicable_styles:
                is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
                if is_match:
                    found_style = style
                    prompt = new_prompt
                    negative_prompt = new_neg_prompt
                    break

            if not found_style:
                break

            applicable_styles.remove(found_style)
            extracted.append(found_style.name)

        return list(reversed(extracted)), prompt, negative_prompt
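
merge_prompts above is the heart of style application: a style containing the literal {prompt} token wraps the user prompt in place; otherwise the style text is appended after a comma. Illustrative calls (the strings are made up):

    merge_prompts("masterpiece, {prompt}, 4k", "a cat")  # -> "masterpiece, a cat, 4k"
    merge_prompts("masterpiece, 4k", "a cat")            # -> "a cat, masterpiece, 4k"
    apply_styles_to_prompt("a cat", ["detailed, {prompt}", "by someone"])  # folds styles in one at a time
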
modules/sub_quadratic_attention.py
ADDED
@@ -0,0 +1,215 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
# "Self-attention Does Not Need O(n²) Memory":
# https://arxiv.org/abs/2112.05682v2

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, List


def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)


class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor


# callable type stub, used only for the annotations below
class SummarizeChunk:
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...


# callable type stub, used only for the annotations below
class ComputeQueryChunkAttn:
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> Tensor: ...


def _summarize_chunk(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    attn_weights = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key.transpose(1, 2),
        alpha=scale,
        beta=0,
    )
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    exp_weights = torch.exp(attn_weights - max_score)
    exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)


def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    batch_x_heads, k_tokens, k_channels_per_head = key.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = narrow_trunc(
            key,
            1,
            chunk_idx,
            kv_chunk_size
        )
        value_chunk = narrow_trunc(
            value,
            1,
            chunk_idx,
            kv_chunk_size
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    chunks: List[AttnChunk] = [
        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> Tensor:
    attn_scores = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key.transpose(1, 2),
        alpha=scale,
        beta=0,
    )
    attn_probs = attn_scores.softmax(dim=-1)
    del attn_scores
    hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
    return hidden_states_slice


class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk


def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Computes efficient dot-product attention given query, key, and value.
    This is an efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        key: keys for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
    Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, k_tokens, _ = key.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return narrow_trunc(
            query,
            1,
            chunk_idx,
            min(query_chunk_size, q_tokens)
        )

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale
    ) if k_tokens <= kv_chunk_size else (
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key=key,
            value=value,
        )

    res = torch.zeros_like(query)
    for i in range(math.ceil(q_tokens / query_chunk_size)):
        attn_scores = compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key=key,
            value=value,
        )

        res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores

    return res
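
Since the chunked path must agree numerically with plain attention, a smoke test against a naive reference is the quickest sanity check. A minimal sketch (shapes follow the docstring, [batch * heads, tokens, channels_per_head]; the sizes and tolerance are arbitrary):

    import torch

    q = torch.randn(2, 4096, 40)
    k = torch.randn(2, 4096, 40)
    v = torch.randn(2, 4096, 40)

    out = efficient_dot_product_attention(q, k, v, query_chunk_size=1024, use_checkpoint=False)
    ref = torch.softmax(q @ k.transpose(1, 2) * q.shape[-1] ** -0.5, dim=-1) @ v
    assert torch.allclose(out, ref, atol=1e-3)  # loose tolerance absorbs float32 chunking error
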
modules/sysinfo.py
ADDED
@@ -0,0 +1,162 @@
import json
import os
import sys
import traceback

import platform
import hashlib
import pkg_resources
import psutil
import re

import launch
from modules import paths_internal, timer

checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = {
    "GIT",
    "INDEX_URL",
    "WEBUI_LAUNCH_LIVE_OUTPUT",
    "GRADIO_ANALYTICS_ENABLED",
    "PYTHONPATH",
    "TORCH_INDEX_URL",
    "TORCH_COMMAND",
    "REQS_FILE",
    "XFORMERS_PACKAGE",
    "GFPGAN_PACKAGE",
    "CLIP_PACKAGE",
    "OPENCLIP_PACKAGE",
    "STABLE_DIFFUSION_REPO",
    "K_DIFFUSION_REPO",
    "CODEFORMER_REPO",
    "BLIP_REPO",
    "STABLE_DIFFUSION_COMMIT_HASH",
    "K_DIFFUSION_COMMIT_HASH",
    "CODEFORMER_COMMIT_HASH",
    "BLIP_COMMIT_HASH",
    "COMMANDLINE_ARGS",
    "IGNORE_CMD_ARGS_ERRORS",
}


def pretty_bytes(num, suffix="B"):
    for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]:
        if abs(num) < 1024 or unit == 'Y':
            return f"{num:.0f}{unit}{suffix}"
        num /= 1024


def get():
    res = get_dict()

    text = json.dumps(res, ensure_ascii=False, indent=4)

    h = hashlib.sha256(text.encode("utf8"))
    text = text.replace(checksum_token, h.hexdigest())

    return text


re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"')


def check(x):
    m = re.search(re_checksum, x)
    if not m:
        return False

    replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x)

    h = hashlib.sha256(replaced.encode("utf8"))
    return h.hexdigest() == m.group(1)


def get_dict():
    ram = psutil.virtual_memory()

    res = {
        "Platform": platform.platform(),
        "Python": platform.python_version(),
        "Version": launch.git_tag(),
        "Commit": launch.commit_hash(),
        "Script path": paths_internal.script_path,
        "Data path": paths_internal.data_path,
        "Extensions dir": paths_internal.extensions_dir,
        "Checksum": checksum_token,
        "Commandline": sys.argv,
        "Torch env info": get_torch_sysinfo(),
        "Exceptions": get_exceptions(),
        "CPU": {
            "model": platform.processor(),
            "count logical": psutil.cpu_count(logical=True),
            "count physical": psutil.cpu_count(logical=False),
        },
        "RAM": {
            x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
        },
        "Extensions": get_extensions(enabled=True),
        "Inactive extensions": get_extensions(enabled=False),
        "Environment": get_environment(),
        "Config": get_config(),
        "Startup": timer.startup_record,
        "Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
    }

    return res


def format_traceback(tb):
    return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]


def get_exceptions():
    try:
        from modules import errors

        return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
    except Exception as e:
        return str(e)


def get_environment():
    return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}


re_newline = re.compile(r"\r*\n")


def get_torch_sysinfo():
    try:
        import torch.utils.collect_env
        info = torch.utils.collect_env.get_env_info()._asdict()

        return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()}
    except Exception as e:
        return str(e)


def get_extensions(*, enabled):
    try:
        from modules import extensions

        def to_json(x: extensions.Extension):
            return {
                "name": x.name,
                "path": x.path,
                "version": x.version,
                "branch": x.branch,
                "remote": x.remote,
            }

        return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
    except Exception as e:
        return str(e)


def get_config():
    try:
        from modules import shared
        return shared.opts.data
    except Exception as e:
        return str(e)
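
The checksum scheme makes a pasted report self-validating: get() hashes the JSON while the placeholder token still occupies the "Checksum" field and then substitutes the digest, and check() reverses the substitution before re-hashing. A sketch of the round trip (assumes the webui's modules are importable):

    from modules import sysinfo

    report = sysinfo.get()
    assert sysinfo.check(report)  # valid as produced
    assert not sysinfo.check(report.replace("Platform", "Platfrom"))  # any edit breaks the digest
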
modules/textual_inversion/__pycache__/autocrop.cpython-310.pyc
ADDED
Binary file (8.73 kB)

modules/textual_inversion/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (9.71 kB)

modules/textual_inversion/__pycache__/image_embedding.cpython-310.pyc
ADDED
Binary file (7.87 kB)

modules/textual_inversion/__pycache__/learn_schedule.cpython-310.pyc
ADDED
Binary file (2.74 kB)

modules/textual_inversion/__pycache__/logging.cpython-310.pyc
ADDED
Binary file (1.68 kB)

modules/textual_inversion/__pycache__/preprocess.cpython-310.pyc
ADDED
Binary file (7.17 kB)

modules/textual_inversion/__pycache__/textual_inversion.cpython-310.pyc
ADDED
Binary file (20.7 kB)

modules/textual_inversion/__pycache__/ui.cpython-310.pyc
ADDED
Binary file (1.61 kB)
modules/textual_inversion/autocrop.py
ADDED
@@ -0,0 +1,340 @@
import cv2
import requests
import os
import numpy as np
from PIL import ImageDraw

GREEN = "#0F0"
BLUE = "#00F"
RED = "#F00"


def crop_image(im, settings):
    """ Intelligently crop an image to the subject matter """

    scale_by = 1
    if is_landscape(im.width, im.height):
        scale_by = settings.crop_height / im.height
    elif is_portrait(im.width, im.height):
        scale_by = settings.crop_width / im.width
    elif is_square(im.width, im.height):
        if is_square(settings.crop_width, settings.crop_height):
            scale_by = settings.crop_width / im.width
        elif is_landscape(settings.crop_width, settings.crop_height):
            scale_by = settings.crop_width / im.width
        elif is_portrait(settings.crop_width, settings.crop_height):
            scale_by = settings.crop_height / im.height

    im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
    im_debug = im.copy()

    focus = focal_point(im_debug, settings)

    # take the focal point and turn it into crop coordinates that try to center over the focal
    # point but then get adjusted back into the frame
    y_half = int(settings.crop_height / 2)
    x_half = int(settings.crop_width / 2)

    x1 = focus.x - x_half
    if x1 < 0:
        x1 = 0
    elif x1 + settings.crop_width > im.width:
        x1 = im.width - settings.crop_width

    y1 = focus.y - y_half
    if y1 < 0:
        y1 = 0
    elif y1 + settings.crop_height > im.height:
        y1 = im.height - settings.crop_height

    x2 = x1 + settings.crop_width
    y2 = y1 + settings.crop_height

    crop = [x1, y1, x2, y2]

    results = []

    results.append(im.crop(tuple(crop)))

    if settings.annotate_image:
        d = ImageDraw.Draw(im_debug)
        rect = list(crop)
        rect[2] -= 1
        rect[3] -= 1
        d.rectangle(rect, outline=GREEN)
        results.append(im_debug)
        if settings.destop_view_image:
            im_debug.show()

    return results


def focal_point(im, settings):
    corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
    entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
    face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []

    pois = []

    weight_pref_total = 0
    if corner_points:
        weight_pref_total += settings.corner_points_weight
    if entropy_points:
        weight_pref_total += settings.entropy_points_weight
    if face_points:
        weight_pref_total += settings.face_points_weight

    corner_centroid = None
    if corner_points:
        corner_centroid = centroid(corner_points)
        corner_centroid.weight = settings.corner_points_weight / weight_pref_total
        pois.append(corner_centroid)

    entropy_centroid = None
    if entropy_points:
        entropy_centroid = centroid(entropy_points)
        entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
        pois.append(entropy_centroid)

    face_centroid = None
    if face_points:
        face_centroid = centroid(face_points)
        face_centroid.weight = settings.face_points_weight / weight_pref_total
        pois.append(face_centroid)

    average_point = poi_average(pois, settings)

    if settings.annotate_image:
        d = ImageDraw.Draw(im)
        max_size = min(im.width, im.height) * 0.07
        if corner_centroid is not None:
            color = BLUE
            box = corner_centroid.bounding(max_size * corner_centroid.weight)
            d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
            d.ellipse(box, outline=color)
            if len(corner_points) > 1:
                for f in corner_points:
                    d.rectangle(f.bounding(4), outline=color)
        if entropy_centroid is not None:
            color = "#ff0"
            box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
            d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
            d.ellipse(box, outline=color)
            if len(entropy_points) > 1:
                for f in entropy_points:
                    d.rectangle(f.bounding(4), outline=color)
        if face_centroid is not None:
            color = RED
            box = face_centroid.bounding(max_size * face_centroid.weight)
            d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
            d.ellipse(box, outline=color)
            if len(face_points) > 1:
                for f in face_points:
                    d.rectangle(f.bounding(4), outline=color)

        d.ellipse(average_point.bounding(max_size), outline=GREEN)

    return average_point


def image_face_points(im, settings):
    if settings.dnn_model_path is not None:
        detector = cv2.FaceDetectorYN.create(
            settings.dnn_model_path,
            "",
            (im.width, im.height),
            0.9,  # score threshold
            0.3,  # nms threshold
            5000  # keep top k before nms
        )
        faces = detector.detect(np.array(im))
        results = []
        if faces[1] is not None:
            for face in faces[1]:
                x = face[0]
                y = face[1]
                w = face[2]
                h = face[3]
                results.append(
                    PointOfInterest(
                        int(x + (w * 0.5)),  # face focus left/right is center
                        int(y + (h * 0.33)),  # face focus up/down is close to the top of the head
                        size=w,
                        weight=1/len(faces[1])
                    )
                )
        return results
    else:
        np_im = np.array(im)
        gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)

        tries = [
            [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],
            [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]
        ]
        for t in tries:
            classifier = cv2.CascadeClassifier(t[0])
            minsize = int(min(im.width, im.height) * t[1])  # at least N percent of the smallest side
            try:
                faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
                                                    minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
            except Exception:
                continue

            # detectMultiScale returns a numpy array when there are hits; testing its
            # truth value directly raises ValueError, so check the length instead
            if len(faces) > 0:
                rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
                return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]), weight=1/len(rects)) for r in rects]
        return []


def image_corner_points(im, settings):
    grayscale = im.convert("L")

    # naive attempt at preventing focal points from collecting at watermarks near the bottom
    gd = ImageDraw.Draw(grayscale)
    gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")

    np_im = np.array(grayscale)

    points = cv2.goodFeaturesToTrack(
        np_im,
        maxCorners=100,
        qualityLevel=0.04,
        minDistance=min(grayscale.width, grayscale.height)*0.06,
        useHarrisDetector=False,
    )

    if points is None:
        return []

    focal_points = []
    for point in points:
        x, y = point.ravel()
        focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))

    return focal_points


def image_entropy_points(im, settings):
    landscape = im.height < im.width
    portrait = im.height > im.width
    if landscape:
        move_idx = [0, 2]
        move_max = im.size[0]
    elif portrait:
        move_idx = [1, 3]
        move_max = im.size[1]
    else:
        return []

    e_max = 0
    crop_current = [0, 0, settings.crop_width, settings.crop_height]
    crop_best = crop_current
    while crop_current[move_idx[1]] < move_max:
        crop = im.crop(tuple(crop_current))
        e = image_entropy(crop)

        if e > e_max:
            e_max = e
            crop_best = list(crop_current)

        crop_current[move_idx[0]] += 4
        crop_current[move_idx[1]] += 4

    x_mid = int(crop_best[0] + settings.crop_width/2)
    y_mid = int(crop_best[1] + settings.crop_height/2)

    return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]


def image_entropy(im):
    # greyscale image entropy
    # band = np.asarray(im.convert("L"))
    band = np.asarray(im.convert("1"), dtype=np.uint8)
    hist, _ = np.histogram(band, bins=range(0, 256))
    hist = hist[hist > 0]
    return -np.log2(hist / hist.sum()).sum()


def centroid(pois):
    x = [poi.x for poi in pois]
    y = [poi.y for poi in pois]
    return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))


def poi_average(pois, settings):
    weight = 0.0
    x = 0.0
    y = 0.0
    for poi in pois:
        weight += poi.weight
        x += poi.x * poi.weight
        y += poi.y * poi.weight
    avg_x = round(weight and x / weight)
    avg_y = round(weight and y / weight)

    return PointOfInterest(avg_x, avg_y)


def is_landscape(w, h):
    return w > h


def is_portrait(w, h):
    return h > w


def is_square(w, h):
    return w == h


def download_and_cache_models(dirname):
    download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
    model_file_name = 'face_detection_yunet.onnx'

    os.makedirs(dirname, exist_ok=True)

    cache_file = os.path.join(dirname, model_file_name)
    if not os.path.exists(cache_file):
        print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
        response = requests.get(download_url)
        with open(cache_file, "wb") as f:
            f.write(response.content)

    if os.path.exists(cache_file):
        return cache_file
    return None


class PointOfInterest:
    def __init__(self, x, y, weight=1.0, size=10):
        self.x = x
        self.y = y
        self.weight = weight
        self.size = size

    def bounding(self, size):
        return [
            self.x - size // 2,
            self.y - size // 2,
            self.x + size // 2,
            self.y + size // 2
        ]


class Settings:
    def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
        self.crop_width = crop_width
        self.crop_height = crop_height
        self.corner_points_weight = corner_points_weight
        self.entropy_points_weight = entropy_points_weight
        self.face_points_weight = face_points_weight
        self.annotate_image = annotate_image
        self.destop_view_image = False
        self.dnn_model_path = dnn_model_path
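
The autocrop pipeline above reduces to: scale the image so one side matches the crop target, gather weighted focal points (corners, entropy, faces), then center the crop window on their average and clamp it back into frame. An illustrative driver (the file paths and cache directory are hypothetical):

    from PIL import Image

    model_path = download_and_cache_models("models/opencv")  # hypothetical cache dir; None falls back to Haar cascades
    settings = Settings(crop_width=512, crop_height=512, face_points_weight=0.9, dnn_model_path=model_path)

    cropped, *annotated = crop_image(Image.open("photo.png"), settings)
    cropped.save("photo.cropped.png")
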
modules/textual_inversion/dataset.py
ADDED
@@ -0,0 +1,246 @@
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices

import random
import tqdm
from modules import devices, shared
import re

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

re_numbers_at_start = re.compile(r"^[-\d]+\s*")


class DatasetEntry:
    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
        self.filename = filename
        self.filename_text = filename_text
        self.weight = weight
        self.latent_dist = latent_dist
        self.latent_sample = latent_sample
        self.cond = cond
        self.cond_text = cond_text
        self.pixel_values = pixel_values


class PersonalizedBase(Dataset):
    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None

        self.placeholder_token = placeholder_token

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        assert os.listdir(data_root), "Dataset directory is empty"

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out
        groups = defaultdict(list)

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            alpha_channel = None
            if shared.state.interrupted:
                raise Exception("interrupted")
            try:
                image = Image.open(path)
                # Currently does not work for single color transparency
                # We would need to read image.info['transparency'] for that
                if use_weight and 'A' in image.getbands():
                    alpha_channel = image.getchannel('A')
                image = image.convert('RGB')
                if not varsize:
                    image = image.resize((width, height), PIL.Image.BICUBIC)
            except Exception:
                continue

            text_filename = f"{os.path.splitext(path)[0]}.txt"
            filename = os.path.basename(path)

            if os.path.exists(text_filename):
                with open(text_filename, "r", encoding="utf8") as file:
                    filename_text = file.read()
            else:
                filename_text = os.path.splitext(filename)[0]
                filename_text = re.sub(re_numbers_at_start, '', filename_text)
                if re_word:
                    tokens = re_word.findall(filename_text)
                    filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)

            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with devices.autocast():
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            # Perform latent sampling, even for random sampling.
            # We need the sample dimensions for the weights
            if latent_sampling_method == "deterministic":
                if isinstance(latent_dist, DiagonalGaussianDistribution):
                    # Works only for DiagonalGaussianDistribution
                    latent_dist.std = 0
                else:
                    latent_sampling_method = "once"
            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)

            if use_weight and alpha_channel is not None:
                channels, *latent_size = latent_sample.shape
                weight_img = alpha_channel.resize(latent_size)
                npweight = np.array(weight_img).astype(np.float32)
                # Repeat for every channel in the latent sample
                weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
                # Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
                weight -= weight.min()
                weight /= weight.mean()
            elif use_weight:
                # If an image does not have an alpha channel, add a ones weight map anyway so we can stack it later
                weight = torch.ones(latent_sample.shape)
            else:
                weight = None

            if latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
            else:
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)

            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with devices.autocast():
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
            groups[image.size].append(len(self.dataset))
            self.dataset.append(entry)
            del torchdata
            del latent_dist
            del latent_sample
            del weight

        self.length = len(self.dataset)
        self.groups = list(groups.values())
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

        if len(groups) > 1:
            print("Buckets:")
            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
                print(f" {w}x{h}: {len(ids)}")
            print()

    def create_text(self, filename_text):
        text = random.choice(self.lines)
        tags = filename_text.split(',')
        if self.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > self.tag_drop_out]
        if self.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        text = text.replace("[name]", self.placeholder_token)
        return text

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        entry = self.dataset[i]
        if self.tag_drop_out != 0 or self.shuffle_tags:
            entry.cond_text = self.create_text(entry.filename_text)
        if self.latent_sampling_method == "random":
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
        return entry


class GroupedBatchSampler(Sampler):
    def __init__(self, data_source: PersonalizedBase, batch_size: int):
        super().__init__(data_source)

        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        self.base = [int(e) // batch_size for e in expected]
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        return self.len

    def __iter__(self):
        b = self.batch_size

        for g in self.groups:
            shuffle(g)

        batches = []
        for g in self.groups:
            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)
+
yield from batches
|
208 |
+
|
209 |
+
|
210 |
+
class PersonalizedDataLoader(DataLoader):
|
211 |
+
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
212 |
+
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
213 |
+
if latent_sampling_method == "random":
|
214 |
+
self.collate_fn = collate_wrapper_random
|
215 |
+
else:
|
216 |
+
self.collate_fn = collate_wrapper
|
217 |
+
|
218 |
+
|
219 |
+
class BatchLoader:
|
220 |
+
def __init__(self, data):
|
221 |
+
self.cond_text = [entry.cond_text for entry in data]
|
222 |
+
self.cond = [entry.cond for entry in data]
|
223 |
+
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
224 |
+
if all(entry.weight is not None for entry in data):
|
225 |
+
self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
|
226 |
+
else:
|
227 |
+
self.weight = None
|
228 |
+
#self.emb_index = [entry.emb_index for entry in data]
|
229 |
+
#print(self.latent_sample.device)
|
230 |
+
|
231 |
+
def pin_memory(self):
|
232 |
+
self.latent_sample = self.latent_sample.pin_memory()
|
233 |
+
return self
|
234 |
+
|
235 |
+
def collate_wrapper(batch):
|
236 |
+
return BatchLoader(batch)
|
237 |
+
|
238 |
+
class BatchLoaderRandom(BatchLoader):
|
239 |
+
def __init__(self, data):
|
240 |
+
super().__init__(data)
|
241 |
+
|
242 |
+
def pin_memory(self):
|
243 |
+
return self
|
244 |
+
|
245 |
+
def collate_wrapper_random(batch):
|
246 |
+
return BatchLoaderRandom(batch)
|
modules/textual_inversion/image_embedding.py
ADDED
@@ -0,0 +1,220 @@
import base64
import json
import warnings

import numpy as np
import zlib
from PIL import Image, ImageDraw
import torch


class EmbeddingEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
        return json.JSONEncoder.default(self, obj)


class EmbeddingDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d):
        if 'TORCHTENSOR' in d:
            return torch.from_numpy(np.array(d['TORCHTENSOR']))
        return d


def embedding_to_b64(data):
    d = json.dumps(data, cls=EmbeddingEncoder)
    return base64.b64encode(d.encode())


def embedding_from_b64(data):
    d = base64.b64decode(data)
    return json.loads(d, cls=EmbeddingDecoder)


def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    while True:
        seed = (a * seed + c) % m
        yield seed % 255


def xor_block(block):
    g = lcg()
    randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)


def style_block(block, sequence):
    im = Image.new('RGB', (block.shape[1], block.shape[0]))
    draw = ImageDraw.Draw(im)
    i = 0
    for x in range(-6, im.size[0], 8):
        for yi, y in enumerate(range(-6, im.size[1], 8)):
            offset = 0
            if yi % 2 == 0:
                offset = 4
            shade = sequence[i % len(sequence)]
            i += 1
            draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))

    fg = np.array(im).astype(np.uint8) & 0xF0

    return block ^ fg


def insert_image_data_embed(image, data):
    d = 3
    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
    data_np_high = data_np_ >> 4
    data_np_low = data_np_ & 0x0F

    h = image.size[1]
    next_size = data_np_low.shape[0] + (h - (data_np_low.shape[0] % h))
    next_size = next_size + ((h * d) - (next_size % (h * d)))

    data_np_low = np.resize(data_np_low, next_size)
    data_np_low = data_np_low.reshape((h, -1, d))

    data_np_high = np.resize(data_np_high, next_size)
    data_np_high = data_np_high.reshape((h, -1, d))

    edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
    edge_style = (np.abs(edge_style) / np.max(np.abs(edge_style)) * 255).astype(np.uint8)

    data_np_low = style_block(data_np_low, sequence=edge_style)
    data_np_low = xor_block(data_np_low)
    data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
    data_np_high = xor_block(data_np_high)

    im_low = Image.fromarray(data_np_low, mode='RGB')
    im_high = Image.fromarray(data_np_high, mode='RGB')

    background = Image.new('RGB', (image.size[0] + im_low.size[0] + im_high.size[0] + 2, image.size[1]), (0, 0, 0))
    background.paste(im_low, (0, 0))
    background.paste(image, (im_low.size[0] + 1, 0))
    background.paste(im_high, (im_low.size[0] + 1 + image.size[0] + 1, 0))

    return background


def crop_black(img, tol=0):
    mask = (img > tol).all(2)
    mask0, mask1 = mask.any(0), mask.any(1)
    col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax()
    row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax()
    return img[row_start:row_end, col_start:col_end]


def extract_image_data_embed(image):
    d = 3
    outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
    black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
    if black_cols[0].shape[0] < 2:
        print('No image data blocks found.')
        return None

    data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
    data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)

    data_block_lower = xor_block(data_block_lower)
    data_block_upper = xor_block(data_block_upper)

    data_block = (data_block_upper << 4) | (data_block_lower)
    data_block = data_block.flatten().tobytes()

    data = zlib.decompress(data_block)
    return json.loads(data, cls=EmbeddingDecoder)


def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
    from modules.images import get_font
    if textfont:
        warnings.warn(
            'passing in a textfont to caption_image_overlay is deprecated and does nothing',
            DeprecationWarning,
            stacklevel=2,
        )
    from math import cos

    image = srcimage.copy()
    fontsize = 32
    factor = 1.5
    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
    for y in range(image.size[1]):
        mag = 1 - cos(y / image.size[1] * factor)
        mag = max(mag, 1 - cos((image.size[1] - y) / image.size[1] * factor * 1.1))
        gradient.putpixel((0, y), (0, 0, 0, int(mag * 255)))
    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))

    draw = ImageDraw.Draw(image)

    font = get_font(fontsize)
    padding = 10

    _, _, w, h = draw.textbbox((0, 0), title, font=font)
    fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
    font = get_font(fontsize)
    _, _, w, h = draw.textbbox((0, 0), title, font=font)
    draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))

    _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
    fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
    _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
    fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
    _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
    fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)

    font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))

    draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))

    return image


if __name__ == '__main__':

    testEmbed = Image.open('test_embedding.png')
    data = extract_image_data_embed(testEmbed)
    assert data is not None

    data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
    assert data is not None

    image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
    cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')

    test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}

    embedded_image = insert_image_data_embed(cap_image, test_embed)

    retrieved_embed = extract_image_data_embed(embedded_image)

    assert str(retrieved_embed) == str(test_embed)

    embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)

    assert embedded_image == embedded_image2

    g = lcg()
    shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()

    reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
                        95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
                        160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
                        38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
                        30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
                        41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
                        66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
                        204, 86, 73, 222, 44, 198, 118, 240, 97]

    assert shared_random == reference_random

    hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())

    assert 12731374 == hunna_kay_random_sum
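
As a usage illustration, a minimal round-trip sketch for the base64 helpers above; the payload shape and the 'name' key are made-up example values, and the snippet assumes embedding_to_b64 and embedding_from_b64 are in scope (e.g. imported from this module):

import torch

embed = {'string_to_param': {'*': torch.zeros(2, 768)}, 'name': 'example-embedding'}  # hypothetical payload
b64 = embedding_to_b64(embed)       # bytes, the format stored in the PNG 'sd-ti-embedding' text chunk
restored = embedding_from_b64(b64)

# the decoder rebuilds tensors via numpy (as float64), so compare by value rather than by dtype
assert restored['string_to_param']['*'].tolist() == embed['string_to_param']['*'].tolist()
assert restored['name'] == 'example-embedding'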
modules/textual_inversion/learn_schedule.py
ADDED
@@ -0,0 +1,81 @@
import tqdm


class LearnScheduleIterator:
    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
        """

        pairs = learn_rate.split(',')
        self.rates = []
        self.it = 0
        self.maxit = 0
        try:
            for pair in pairs:
                if not pair.strip():
                    continue
                tmp = pair.split(':')
                if len(tmp) == 2:
                    step = int(tmp[1])
                    if step > cur_step:
                        self.rates.append((float(tmp[0]), min(step, max_steps)))
                        self.maxit += 1
                        if step > max_steps:
                            return
                    elif step == -1:
                        self.rates.append((float(tmp[0]), max_steps))
                        self.maxit += 1
                        return
                else:
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            assert self.rates
        except (ValueError, AssertionError) as e:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e

    def __iter__(self):
        return self

    def __next__(self):
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration


class LearnRateScheduler:
    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        (self.learn_rate, self.end_step) = next(self.schedules)
        self.verbose = verbose

        if self.verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def step(self, step_number):
        if step_number < self.end_step:
            return False

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        if not self.step(step_number):
            return

        if self.verbose:
            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        for pg in optimizer.param_groups:
            pg['lr'] = self.learn_rate
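
As a usage illustration, a minimal sketch of driving LearnRateScheduler with a throwaway optimizer; the schedule string and step count are made-up example values, and the snippet assumes LearnRateScheduler is in scope. Note that apply() only changes the rate when a schedule boundary is crossed, so the caller seeds the optimizer with the first scheduled rate:

import torch

param = torch.nn.Parameter(torch.zeros(1))  # stand-in for real trainable weights

scheduler = LearnRateScheduler("0.01:100, 0.001:200", max_steps=200)
optimizer = torch.optim.SGD([param], lr=scheduler.learn_rate)  # start at the first scheduled rate

for step_number in range(200):
    scheduler.apply(optimizer, step_number)  # bumps lr from 0.01 to 0.001 once step_number reaches 100
    # ...one training step would go here...

print(optimizer.param_groups[0]['lr'])  # 0.001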