5ddaf0cca1e144f6e16563ce10f39c55b15e5da289b24cf2ef5d01fe66eaa922
Browse files- modules/call_queue.py +117 -0
- modules/cmd_args.py +113 -0
- modules/codeformer/__pycache__/codeformer_arch.cpython-310.pyc +0 -0
- modules/codeformer/__pycache__/vqgan_arch.cpython-310.pyc +0 -0
- modules/codeformer/codeformer_arch.py +276 -0
- modules/codeformer/vqgan_arch.py +435 -0
- modules/codeformer_model.py +132 -0
- modules/config_states.py +197 -0
- modules/deepbooru.py +98 -0
- modules/deepbooru_model.py +678 -0
- modules/devices.py +171 -0
- modules/errors.py +85 -0
- modules/esrgan_model.py +229 -0
- modules/esrgan_model_arch.py +465 -0
- modules/extensions.py +163 -0
- modules/extra_networks.py +179 -0
- modules/extra_networks_hypernet.py +28 -0
- modules/extras.py +303 -0
- modules/face_restoration.py +19 -0
- modules/generation_parameters_copypaste.py +439 -0
- modules/gfpgan_model.py +110 -0
- modules/gitpython_hack.py +42 -0
- modules/hashes.py +81 -0
- modules/hypernetworks/__pycache__/hypernetwork.cpython-310.pyc +0 -0
- modules/hypernetworks/__pycache__/ui.cpython-310.pyc +0 -0
- modules/hypernetworks/hypernetwork.py +783 -0
- modules/hypernetworks/ui.py +38 -0
- modules/images.py +758 -0
- modules/img2img.py +245 -0
- modules/import_hook.py +5 -0
- modules/interrogate.py +223 -0
- modules/launch_utils.py +415 -0
- modules/localization.py +35 -0
- modules/lowvram.py +130 -0
- modules/mac_specific.py +86 -0
- modules/masking.py +99 -0
- modules/memmon.py +92 -0
- modules/modelloader.py +179 -0
- modules/models/diffusion/ddpm_edit.py +1455 -0
- modules/models/diffusion/uni_pc/__init__.py +1 -0
- modules/models/diffusion/uni_pc/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/models/diffusion/uni_pc/__pycache__/sampler.cpython-310.pyc +0 -0
- modules/models/diffusion/uni_pc/__pycache__/uni_pc.cpython-310.pyc +0 -0
- modules/models/diffusion/uni_pc/sampler.py +101 -0
- modules/models/diffusion/uni_pc/uni_pc.py +863 -0
- modules/ngrok.py +30 -0
- modules/paths.py +65 -0
- modules/paths_internal.py +31 -0
- modules/postprocessing.py +109 -0
- modules/processing.py +1405 -0
modules/call_queue.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import wraps
|
2 |
+
import html
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
|
6 |
+
from modules import shared, progress, errors
|
7 |
+
|
8 |
+
queue_lock = threading.Lock()
|
9 |
+
|
10 |
+
|
11 |
+
def wrap_queued_call(func):
|
12 |
+
def f(*args, **kwargs):
|
13 |
+
with queue_lock:
|
14 |
+
res = func(*args, **kwargs)
|
15 |
+
|
16 |
+
return res
|
17 |
+
|
18 |
+
return f
|
19 |
+
|
20 |
+
|
21 |
+
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
22 |
+
@wraps(func)
|
23 |
+
def f(*args, **kwargs):
|
24 |
+
|
25 |
+
# if the first argument is a string that says "task(...)", it is treated as a job id
|
26 |
+
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
|
27 |
+
id_task = args[0]
|
28 |
+
progress.add_task_to_queue(id_task)
|
29 |
+
else:
|
30 |
+
id_task = None
|
31 |
+
|
32 |
+
with queue_lock:
|
33 |
+
shared.state.begin(job=id_task)
|
34 |
+
progress.start_task(id_task)
|
35 |
+
|
36 |
+
try:
|
37 |
+
res = func(*args, **kwargs)
|
38 |
+
progress.record_results(id_task, res)
|
39 |
+
finally:
|
40 |
+
progress.finish_task(id_task)
|
41 |
+
|
42 |
+
shared.state.end()
|
43 |
+
|
44 |
+
return res
|
45 |
+
|
46 |
+
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
|
47 |
+
|
48 |
+
|
49 |
+
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
50 |
+
@wraps(func)
|
51 |
+
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
52 |
+
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
53 |
+
if run_memmon:
|
54 |
+
shared.mem_mon.monitor()
|
55 |
+
t = time.perf_counter()
|
56 |
+
|
57 |
+
try:
|
58 |
+
res = list(func(*args, **kwargs))
|
59 |
+
except Exception as e:
|
60 |
+
# When printing out our debug argument list,
|
61 |
+
# do not print out more than a 100 KB of text
|
62 |
+
max_debug_str_len = 131072
|
63 |
+
message = "Error completing request"
|
64 |
+
arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
|
65 |
+
if len(arg_str) > max_debug_str_len:
|
66 |
+
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
|
67 |
+
errors.report(f"{message}\n{arg_str}", exc_info=True)
|
68 |
+
|
69 |
+
shared.state.job = ""
|
70 |
+
shared.state.job_count = 0
|
71 |
+
|
72 |
+
if extra_outputs_array is None:
|
73 |
+
extra_outputs_array = [None, '']
|
74 |
+
|
75 |
+
error_message = f'{type(e).__name__}: {e}'
|
76 |
+
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
|
77 |
+
|
78 |
+
shared.state.skipped = False
|
79 |
+
shared.state.interrupted = False
|
80 |
+
shared.state.job_count = 0
|
81 |
+
|
82 |
+
if not add_stats:
|
83 |
+
return tuple(res)
|
84 |
+
|
85 |
+
elapsed = time.perf_counter() - t
|
86 |
+
elapsed_m = int(elapsed // 60)
|
87 |
+
elapsed_s = elapsed % 60
|
88 |
+
elapsed_text = f"{elapsed_s:.1f} sec."
|
89 |
+
if elapsed_m > 0:
|
90 |
+
elapsed_text = f"{elapsed_m} min. "+elapsed_text
|
91 |
+
|
92 |
+
if run_memmon:
|
93 |
+
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
94 |
+
active_peak = mem_stats['active_peak']
|
95 |
+
reserved_peak = mem_stats['reserved_peak']
|
96 |
+
sys_peak = mem_stats['system_peak']
|
97 |
+
sys_total = mem_stats['total']
|
98 |
+
sys_pct = sys_peak/max(sys_total, 1) * 100
|
99 |
+
|
100 |
+
toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
|
101 |
+
toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
|
102 |
+
toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
|
103 |
+
|
104 |
+
text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
|
105 |
+
text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
|
106 |
+
text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
|
107 |
+
|
108 |
+
vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
|
109 |
+
else:
|
110 |
+
vram_html = ''
|
111 |
+
|
112 |
+
# last item is always HTML
|
113 |
+
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
|
114 |
+
|
115 |
+
return tuple(res)
|
116 |
+
|
117 |
+
return f
|
modules/cmd_args.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
5 |
+
|
6 |
+
parser = argparse.ArgumentParser()
|
7 |
+
|
8 |
+
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
|
9 |
+
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
|
10 |
+
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
|
11 |
+
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
12 |
+
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
13 |
+
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
14 |
+
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
15 |
+
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
16 |
+
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
17 |
+
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
18 |
+
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
19 |
+
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
20 |
+
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
21 |
+
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
22 |
+
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
23 |
+
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
24 |
+
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
25 |
+
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
26 |
+
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
27 |
+
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
28 |
+
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
29 |
+
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
30 |
+
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
31 |
+
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
32 |
+
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
33 |
+
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
34 |
+
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
35 |
+
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
36 |
+
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
37 |
+
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
38 |
+
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
39 |
+
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
40 |
+
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
41 |
+
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
42 |
+
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
43 |
+
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
44 |
+
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
45 |
+
parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
|
46 |
+
parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
|
47 |
+
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
48 |
+
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
49 |
+
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
50 |
+
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
51 |
+
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
52 |
+
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
53 |
+
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
54 |
+
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
55 |
+
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
56 |
+
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
57 |
+
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
58 |
+
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
|
59 |
+
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
|
60 |
+
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
61 |
+
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
62 |
+
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
63 |
+
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
|
64 |
+
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
|
65 |
+
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
|
66 |
+
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
|
67 |
+
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
68 |
+
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
69 |
+
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
70 |
+
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
71 |
+
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
72 |
+
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
73 |
+
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
74 |
+
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
75 |
+
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
76 |
+
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
77 |
+
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
78 |
+
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
79 |
+
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
80 |
+
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
81 |
+
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
82 |
+
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
|
83 |
+
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
84 |
+
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
85 |
+
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
86 |
+
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
87 |
+
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
88 |
+
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
89 |
+
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
90 |
+
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
91 |
+
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
92 |
+
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
93 |
+
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
94 |
+
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
95 |
+
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
96 |
+
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
97 |
+
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
98 |
+
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
99 |
+
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
100 |
+
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
101 |
+
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
102 |
+
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
103 |
+
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
104 |
+
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
105 |
+
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
106 |
+
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
107 |
+
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
108 |
+
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
109 |
+
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
110 |
+
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
111 |
+
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
112 |
+
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
113 |
+
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
modules/codeformer/__pycache__/codeformer_arch.cpython-310.pyc
ADDED
Binary file (9.16 kB). View file
|
|
modules/codeformer/__pycache__/vqgan_arch.cpython-310.pyc
ADDED
Binary file (11.1 kB). View file
|
|
modules/codeformer/codeformer_arch.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
2 |
+
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
10 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
11 |
+
|
12 |
+
def calc_mean_std(feat, eps=1e-5):
|
13 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
feat (Tensor): 4D tensor.
|
17 |
+
eps (float): A small value added to the variance to avoid
|
18 |
+
divide-by-zero. Default: 1e-5.
|
19 |
+
"""
|
20 |
+
size = feat.size()
|
21 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
22 |
+
b, c = size[:2]
|
23 |
+
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
24 |
+
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
25 |
+
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
26 |
+
return feat_mean, feat_std
|
27 |
+
|
28 |
+
|
29 |
+
def adaptive_instance_normalization(content_feat, style_feat):
|
30 |
+
"""Adaptive instance normalization.
|
31 |
+
|
32 |
+
Adjust the reference features to have the similar color and illuminations
|
33 |
+
as those in the degradate features.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
content_feat (Tensor): The reference feature.
|
37 |
+
style_feat (Tensor): The degradate features.
|
38 |
+
"""
|
39 |
+
size = content_feat.size()
|
40 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
41 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
42 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
43 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
44 |
+
|
45 |
+
|
46 |
+
class PositionEmbeddingSine(nn.Module):
|
47 |
+
"""
|
48 |
+
This is a more standard version of the position embedding, very similar to the one
|
49 |
+
used by the Attention is all you need paper, generalized to work on images.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
53 |
+
super().__init__()
|
54 |
+
self.num_pos_feats = num_pos_feats
|
55 |
+
self.temperature = temperature
|
56 |
+
self.normalize = normalize
|
57 |
+
if scale is not None and normalize is False:
|
58 |
+
raise ValueError("normalize should be True if scale is passed")
|
59 |
+
if scale is None:
|
60 |
+
scale = 2 * math.pi
|
61 |
+
self.scale = scale
|
62 |
+
|
63 |
+
def forward(self, x, mask=None):
|
64 |
+
if mask is None:
|
65 |
+
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
66 |
+
not_mask = ~mask
|
67 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
68 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
69 |
+
if self.normalize:
|
70 |
+
eps = 1e-6
|
71 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
72 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
73 |
+
|
74 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
75 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
76 |
+
|
77 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
78 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
79 |
+
pos_x = torch.stack(
|
80 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
81 |
+
).flatten(3)
|
82 |
+
pos_y = torch.stack(
|
83 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
84 |
+
).flatten(3)
|
85 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
86 |
+
return pos
|
87 |
+
|
88 |
+
def _get_activation_fn(activation):
|
89 |
+
"""Return an activation function given a string"""
|
90 |
+
if activation == "relu":
|
91 |
+
return F.relu
|
92 |
+
if activation == "gelu":
|
93 |
+
return F.gelu
|
94 |
+
if activation == "glu":
|
95 |
+
return F.glu
|
96 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
97 |
+
|
98 |
+
|
99 |
+
class TransformerSALayer(nn.Module):
|
100 |
+
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
101 |
+
super().__init__()
|
102 |
+
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
103 |
+
# Implementation of Feedforward model - MLP
|
104 |
+
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
105 |
+
self.dropout = nn.Dropout(dropout)
|
106 |
+
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
107 |
+
|
108 |
+
self.norm1 = nn.LayerNorm(embed_dim)
|
109 |
+
self.norm2 = nn.LayerNorm(embed_dim)
|
110 |
+
self.dropout1 = nn.Dropout(dropout)
|
111 |
+
self.dropout2 = nn.Dropout(dropout)
|
112 |
+
|
113 |
+
self.activation = _get_activation_fn(activation)
|
114 |
+
|
115 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
116 |
+
return tensor if pos is None else tensor + pos
|
117 |
+
|
118 |
+
def forward(self, tgt,
|
119 |
+
tgt_mask: Optional[Tensor] = None,
|
120 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
121 |
+
query_pos: Optional[Tensor] = None):
|
122 |
+
|
123 |
+
# self attention
|
124 |
+
tgt2 = self.norm1(tgt)
|
125 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
126 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
127 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
128 |
+
tgt = tgt + self.dropout1(tgt2)
|
129 |
+
|
130 |
+
# ffn
|
131 |
+
tgt2 = self.norm2(tgt)
|
132 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
133 |
+
tgt = tgt + self.dropout2(tgt2)
|
134 |
+
return tgt
|
135 |
+
|
136 |
+
class Fuse_sft_block(nn.Module):
|
137 |
+
def __init__(self, in_ch, out_ch):
|
138 |
+
super().__init__()
|
139 |
+
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
140 |
+
|
141 |
+
self.scale = nn.Sequential(
|
142 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
143 |
+
nn.LeakyReLU(0.2, True),
|
144 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
145 |
+
|
146 |
+
self.shift = nn.Sequential(
|
147 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
148 |
+
nn.LeakyReLU(0.2, True),
|
149 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
150 |
+
|
151 |
+
def forward(self, enc_feat, dec_feat, w=1):
|
152 |
+
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
153 |
+
scale = self.scale(enc_feat)
|
154 |
+
shift = self.shift(enc_feat)
|
155 |
+
residual = w * (dec_feat * scale + shift)
|
156 |
+
out = dec_feat + residual
|
157 |
+
return out
|
158 |
+
|
159 |
+
|
160 |
+
@ARCH_REGISTRY.register()
|
161 |
+
class CodeFormer(VQAutoEncoder):
|
162 |
+
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
163 |
+
codebook_size=1024, latent_size=256,
|
164 |
+
connect_list=('32', '64', '128', '256'),
|
165 |
+
fix_modules=('quantize', 'generator')):
|
166 |
+
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
167 |
+
|
168 |
+
if fix_modules is not None:
|
169 |
+
for module in fix_modules:
|
170 |
+
for param in getattr(self, module).parameters():
|
171 |
+
param.requires_grad = False
|
172 |
+
|
173 |
+
self.connect_list = connect_list
|
174 |
+
self.n_layers = n_layers
|
175 |
+
self.dim_embd = dim_embd
|
176 |
+
self.dim_mlp = dim_embd*2
|
177 |
+
|
178 |
+
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
179 |
+
self.feat_emb = nn.Linear(256, self.dim_embd)
|
180 |
+
|
181 |
+
# transformer
|
182 |
+
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
183 |
+
for _ in range(self.n_layers)])
|
184 |
+
|
185 |
+
# logits_predict head
|
186 |
+
self.idx_pred_layer = nn.Sequential(
|
187 |
+
nn.LayerNorm(dim_embd),
|
188 |
+
nn.Linear(dim_embd, codebook_size, bias=False))
|
189 |
+
|
190 |
+
self.channels = {
|
191 |
+
'16': 512,
|
192 |
+
'32': 256,
|
193 |
+
'64': 256,
|
194 |
+
'128': 128,
|
195 |
+
'256': 128,
|
196 |
+
'512': 64,
|
197 |
+
}
|
198 |
+
|
199 |
+
# after second residual block for > 16, before attn layer for ==16
|
200 |
+
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
201 |
+
# after first residual block for > 16, before attn layer for ==16
|
202 |
+
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
203 |
+
|
204 |
+
# fuse_convs_dict
|
205 |
+
self.fuse_convs_dict = nn.ModuleDict()
|
206 |
+
for f_size in self.connect_list:
|
207 |
+
in_ch = self.channels[f_size]
|
208 |
+
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
209 |
+
|
210 |
+
def _init_weights(self, module):
|
211 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
212 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
213 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
214 |
+
module.bias.data.zero_()
|
215 |
+
elif isinstance(module, nn.LayerNorm):
|
216 |
+
module.bias.data.zero_()
|
217 |
+
module.weight.data.fill_(1.0)
|
218 |
+
|
219 |
+
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
220 |
+
# ################### Encoder #####################
|
221 |
+
enc_feat_dict = {}
|
222 |
+
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
223 |
+
for i, block in enumerate(self.encoder.blocks):
|
224 |
+
x = block(x)
|
225 |
+
if i in out_list:
|
226 |
+
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
227 |
+
|
228 |
+
lq_feat = x
|
229 |
+
# ################# Transformer ###################
|
230 |
+
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
231 |
+
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
232 |
+
# BCHW -> BC(HW) -> (HW)BC
|
233 |
+
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
234 |
+
query_emb = feat_emb
|
235 |
+
# Transformer encoder
|
236 |
+
for layer in self.ft_layers:
|
237 |
+
query_emb = layer(query_emb, query_pos=pos_emb)
|
238 |
+
|
239 |
+
# output logits
|
240 |
+
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
241 |
+
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
242 |
+
|
243 |
+
if code_only: # for training stage II
|
244 |
+
# logits doesn't need softmax before cross_entropy loss
|
245 |
+
return logits, lq_feat
|
246 |
+
|
247 |
+
# ################# Quantization ###################
|
248 |
+
# if self.training:
|
249 |
+
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
250 |
+
# # b(hw)c -> bc(hw) -> bchw
|
251 |
+
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
252 |
+
# ------------
|
253 |
+
soft_one_hot = F.softmax(logits, dim=2)
|
254 |
+
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
255 |
+
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
256 |
+
# preserve gradients
|
257 |
+
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
258 |
+
|
259 |
+
if detach_16:
|
260 |
+
quant_feat = quant_feat.detach() # for training stage III
|
261 |
+
if adain:
|
262 |
+
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
263 |
+
|
264 |
+
# ################## Generator ####################
|
265 |
+
x = quant_feat
|
266 |
+
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
267 |
+
|
268 |
+
for i, block in enumerate(self.generator.blocks):
|
269 |
+
x = block(x)
|
270 |
+
if i in fuse_list: # fuse after i-th block
|
271 |
+
f_size = str(x.shape[-1])
|
272 |
+
if w>0:
|
273 |
+
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
274 |
+
out = x
|
275 |
+
# logits doesn't need softmax before cross_entropy loss
|
276 |
+
return out, logits, lq_feat
|
modules/codeformer/vqgan_arch.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
2 |
+
|
3 |
+
'''
|
4 |
+
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
5 |
+
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
6 |
+
|
7 |
+
'''
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from basicsr.utils import get_root_logger
|
12 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
13 |
+
|
14 |
+
def normalize(in_channels):
|
15 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
16 |
+
|
17 |
+
|
18 |
+
@torch.jit.script
|
19 |
+
def swish(x):
|
20 |
+
return x*torch.sigmoid(x)
|
21 |
+
|
22 |
+
|
23 |
+
# Define VQVAE classes
|
24 |
+
class VectorQuantizer(nn.Module):
|
25 |
+
def __init__(self, codebook_size, emb_dim, beta):
|
26 |
+
super(VectorQuantizer, self).__init__()
|
27 |
+
self.codebook_size = codebook_size # number of embeddings
|
28 |
+
self.emb_dim = emb_dim # dimension of embedding
|
29 |
+
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
30 |
+
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
31 |
+
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
32 |
+
|
33 |
+
def forward(self, z):
|
34 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
35 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
36 |
+
z_flattened = z.view(-1, self.emb_dim)
|
37 |
+
|
38 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
39 |
+
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
40 |
+
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
41 |
+
|
42 |
+
mean_distance = torch.mean(d)
|
43 |
+
# find closest encodings
|
44 |
+
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
45 |
+
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
46 |
+
# [0-1], higher score, higher confidence
|
47 |
+
min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
48 |
+
|
49 |
+
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
50 |
+
min_encodings.scatter_(1, min_encoding_indices, 1)
|
51 |
+
|
52 |
+
# get quantized latent vectors
|
53 |
+
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
54 |
+
# compute loss for embedding
|
55 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
56 |
+
# preserve gradients
|
57 |
+
z_q = z + (z_q - z).detach()
|
58 |
+
|
59 |
+
# perplexity
|
60 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
61 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
62 |
+
# reshape back to match original input shape
|
63 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
64 |
+
|
65 |
+
return z_q, loss, {
|
66 |
+
"perplexity": perplexity,
|
67 |
+
"min_encodings": min_encodings,
|
68 |
+
"min_encoding_indices": min_encoding_indices,
|
69 |
+
"min_encoding_scores": min_encoding_scores,
|
70 |
+
"mean_distance": mean_distance
|
71 |
+
}
|
72 |
+
|
73 |
+
def get_codebook_feat(self, indices, shape):
|
74 |
+
# input indices: batch*token_num -> (batch*token_num)*1
|
75 |
+
# shape: batch, height, width, channel
|
76 |
+
indices = indices.view(-1,1)
|
77 |
+
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
78 |
+
min_encodings.scatter_(1, indices, 1)
|
79 |
+
# get quantized latent vectors
|
80 |
+
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
81 |
+
|
82 |
+
if shape is not None: # reshape back to match original input shape
|
83 |
+
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
84 |
+
|
85 |
+
return z_q
|
86 |
+
|
87 |
+
|
88 |
+
class GumbelQuantizer(nn.Module):
|
89 |
+
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
90 |
+
super().__init__()
|
91 |
+
self.codebook_size = codebook_size # number of embeddings
|
92 |
+
self.emb_dim = emb_dim # dimension of embedding
|
93 |
+
self.straight_through = straight_through
|
94 |
+
self.temperature = temp_init
|
95 |
+
self.kl_weight = kl_weight
|
96 |
+
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
97 |
+
self.embed = nn.Embedding(codebook_size, emb_dim)
|
98 |
+
|
99 |
+
def forward(self, z):
|
100 |
+
hard = self.straight_through if self.training else True
|
101 |
+
|
102 |
+
logits = self.proj(z)
|
103 |
+
|
104 |
+
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
105 |
+
|
106 |
+
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
107 |
+
|
108 |
+
# + kl divergence to the prior loss
|
109 |
+
qy = F.softmax(logits, dim=1)
|
110 |
+
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
111 |
+
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
112 |
+
|
113 |
+
return z_q, diff, {
|
114 |
+
"min_encoding_indices": min_encoding_indices
|
115 |
+
}
|
116 |
+
|
117 |
+
|
118 |
+
class Downsample(nn.Module):
|
119 |
+
def __init__(self, in_channels):
|
120 |
+
super().__init__()
|
121 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
pad = (0, 1, 0, 1)
|
125 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
126 |
+
x = self.conv(x)
|
127 |
+
return x
|
128 |
+
|
129 |
+
|
130 |
+
class Upsample(nn.Module):
|
131 |
+
def __init__(self, in_channels):
|
132 |
+
super().__init__()
|
133 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
134 |
+
|
135 |
+
def forward(self, x):
|
136 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
137 |
+
x = self.conv(x)
|
138 |
+
|
139 |
+
return x
|
140 |
+
|
141 |
+
|
142 |
+
class ResBlock(nn.Module):
|
143 |
+
def __init__(self, in_channels, out_channels=None):
|
144 |
+
super(ResBlock, self).__init__()
|
145 |
+
self.in_channels = in_channels
|
146 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
147 |
+
self.norm1 = normalize(in_channels)
|
148 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
149 |
+
self.norm2 = normalize(out_channels)
|
150 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
151 |
+
if self.in_channels != self.out_channels:
|
152 |
+
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
153 |
+
|
154 |
+
def forward(self, x_in):
|
155 |
+
x = x_in
|
156 |
+
x = self.norm1(x)
|
157 |
+
x = swish(x)
|
158 |
+
x = self.conv1(x)
|
159 |
+
x = self.norm2(x)
|
160 |
+
x = swish(x)
|
161 |
+
x = self.conv2(x)
|
162 |
+
if self.in_channels != self.out_channels:
|
163 |
+
x_in = self.conv_out(x_in)
|
164 |
+
|
165 |
+
return x + x_in
|
166 |
+
|
167 |
+
|
168 |
+
class AttnBlock(nn.Module):
|
169 |
+
def __init__(self, in_channels):
|
170 |
+
super().__init__()
|
171 |
+
self.in_channels = in_channels
|
172 |
+
|
173 |
+
self.norm = normalize(in_channels)
|
174 |
+
self.q = torch.nn.Conv2d(
|
175 |
+
in_channels,
|
176 |
+
in_channels,
|
177 |
+
kernel_size=1,
|
178 |
+
stride=1,
|
179 |
+
padding=0
|
180 |
+
)
|
181 |
+
self.k = torch.nn.Conv2d(
|
182 |
+
in_channels,
|
183 |
+
in_channels,
|
184 |
+
kernel_size=1,
|
185 |
+
stride=1,
|
186 |
+
padding=0
|
187 |
+
)
|
188 |
+
self.v = torch.nn.Conv2d(
|
189 |
+
in_channels,
|
190 |
+
in_channels,
|
191 |
+
kernel_size=1,
|
192 |
+
stride=1,
|
193 |
+
padding=0
|
194 |
+
)
|
195 |
+
self.proj_out = torch.nn.Conv2d(
|
196 |
+
in_channels,
|
197 |
+
in_channels,
|
198 |
+
kernel_size=1,
|
199 |
+
stride=1,
|
200 |
+
padding=0
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
h_ = x
|
205 |
+
h_ = self.norm(h_)
|
206 |
+
q = self.q(h_)
|
207 |
+
k = self.k(h_)
|
208 |
+
v = self.v(h_)
|
209 |
+
|
210 |
+
# compute attention
|
211 |
+
b, c, h, w = q.shape
|
212 |
+
q = q.reshape(b, c, h*w)
|
213 |
+
q = q.permute(0, 2, 1)
|
214 |
+
k = k.reshape(b, c, h*w)
|
215 |
+
w_ = torch.bmm(q, k)
|
216 |
+
w_ = w_ * (int(c)**(-0.5))
|
217 |
+
w_ = F.softmax(w_, dim=2)
|
218 |
+
|
219 |
+
# attend to values
|
220 |
+
v = v.reshape(b, c, h*w)
|
221 |
+
w_ = w_.permute(0, 2, 1)
|
222 |
+
h_ = torch.bmm(v, w_)
|
223 |
+
h_ = h_.reshape(b, c, h, w)
|
224 |
+
|
225 |
+
h_ = self.proj_out(h_)
|
226 |
+
|
227 |
+
return x+h_
|
228 |
+
|
229 |
+
|
230 |
+
class Encoder(nn.Module):
|
231 |
+
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
232 |
+
super().__init__()
|
233 |
+
self.nf = nf
|
234 |
+
self.num_resolutions = len(ch_mult)
|
235 |
+
self.num_res_blocks = num_res_blocks
|
236 |
+
self.resolution = resolution
|
237 |
+
self.attn_resolutions = attn_resolutions
|
238 |
+
|
239 |
+
curr_res = self.resolution
|
240 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
241 |
+
|
242 |
+
blocks = []
|
243 |
+
# initial convultion
|
244 |
+
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
245 |
+
|
246 |
+
# residual and downsampling blocks, with attention on smaller res (16x16)
|
247 |
+
for i in range(self.num_resolutions):
|
248 |
+
block_in_ch = nf * in_ch_mult[i]
|
249 |
+
block_out_ch = nf * ch_mult[i]
|
250 |
+
for _ in range(self.num_res_blocks):
|
251 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
252 |
+
block_in_ch = block_out_ch
|
253 |
+
if curr_res in attn_resolutions:
|
254 |
+
blocks.append(AttnBlock(block_in_ch))
|
255 |
+
|
256 |
+
if i != self.num_resolutions - 1:
|
257 |
+
blocks.append(Downsample(block_in_ch))
|
258 |
+
curr_res = curr_res // 2
|
259 |
+
|
260 |
+
# non-local attention block
|
261 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
262 |
+
blocks.append(AttnBlock(block_in_ch))
|
263 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
264 |
+
|
265 |
+
# normalise and convert to latent size
|
266 |
+
blocks.append(normalize(block_in_ch))
|
267 |
+
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
268 |
+
self.blocks = nn.ModuleList(blocks)
|
269 |
+
|
270 |
+
def forward(self, x):
|
271 |
+
for block in self.blocks:
|
272 |
+
x = block(x)
|
273 |
+
|
274 |
+
return x
|
275 |
+
|
276 |
+
|
277 |
+
class Generator(nn.Module):
|
278 |
+
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
279 |
+
super().__init__()
|
280 |
+
self.nf = nf
|
281 |
+
self.ch_mult = ch_mult
|
282 |
+
self.num_resolutions = len(self.ch_mult)
|
283 |
+
self.num_res_blocks = res_blocks
|
284 |
+
self.resolution = img_size
|
285 |
+
self.attn_resolutions = attn_resolutions
|
286 |
+
self.in_channels = emb_dim
|
287 |
+
self.out_channels = 3
|
288 |
+
block_in_ch = self.nf * self.ch_mult[-1]
|
289 |
+
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
290 |
+
|
291 |
+
blocks = []
|
292 |
+
# initial conv
|
293 |
+
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
294 |
+
|
295 |
+
# non-local attention block
|
296 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
297 |
+
blocks.append(AttnBlock(block_in_ch))
|
298 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
299 |
+
|
300 |
+
for i in reversed(range(self.num_resolutions)):
|
301 |
+
block_out_ch = self.nf * self.ch_mult[i]
|
302 |
+
|
303 |
+
for _ in range(self.num_res_blocks):
|
304 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
305 |
+
block_in_ch = block_out_ch
|
306 |
+
|
307 |
+
if curr_res in self.attn_resolutions:
|
308 |
+
blocks.append(AttnBlock(block_in_ch))
|
309 |
+
|
310 |
+
if i != 0:
|
311 |
+
blocks.append(Upsample(block_in_ch))
|
312 |
+
curr_res = curr_res * 2
|
313 |
+
|
314 |
+
blocks.append(normalize(block_in_ch))
|
315 |
+
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
316 |
+
|
317 |
+
self.blocks = nn.ModuleList(blocks)
|
318 |
+
|
319 |
+
|
320 |
+
def forward(self, x):
|
321 |
+
for block in self.blocks:
|
322 |
+
x = block(x)
|
323 |
+
|
324 |
+
return x
|
325 |
+
|
326 |
+
|
327 |
+
@ARCH_REGISTRY.register()
|
328 |
+
class VQAutoEncoder(nn.Module):
|
329 |
+
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
330 |
+
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
331 |
+
super().__init__()
|
332 |
+
logger = get_root_logger()
|
333 |
+
self.in_channels = 3
|
334 |
+
self.nf = nf
|
335 |
+
self.n_blocks = res_blocks
|
336 |
+
self.codebook_size = codebook_size
|
337 |
+
self.embed_dim = emb_dim
|
338 |
+
self.ch_mult = ch_mult
|
339 |
+
self.resolution = img_size
|
340 |
+
self.attn_resolutions = attn_resolutions or [16]
|
341 |
+
self.quantizer_type = quantizer
|
342 |
+
self.encoder = Encoder(
|
343 |
+
self.in_channels,
|
344 |
+
self.nf,
|
345 |
+
self.embed_dim,
|
346 |
+
self.ch_mult,
|
347 |
+
self.n_blocks,
|
348 |
+
self.resolution,
|
349 |
+
self.attn_resolutions
|
350 |
+
)
|
351 |
+
if self.quantizer_type == "nearest":
|
352 |
+
self.beta = beta #0.25
|
353 |
+
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
354 |
+
elif self.quantizer_type == "gumbel":
|
355 |
+
self.gumbel_num_hiddens = emb_dim
|
356 |
+
self.straight_through = gumbel_straight_through
|
357 |
+
self.kl_weight = gumbel_kl_weight
|
358 |
+
self.quantize = GumbelQuantizer(
|
359 |
+
self.codebook_size,
|
360 |
+
self.embed_dim,
|
361 |
+
self.gumbel_num_hiddens,
|
362 |
+
self.straight_through,
|
363 |
+
self.kl_weight
|
364 |
+
)
|
365 |
+
self.generator = Generator(
|
366 |
+
self.nf,
|
367 |
+
self.embed_dim,
|
368 |
+
self.ch_mult,
|
369 |
+
self.n_blocks,
|
370 |
+
self.resolution,
|
371 |
+
self.attn_resolutions
|
372 |
+
)
|
373 |
+
|
374 |
+
if model_path is not None:
|
375 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
376 |
+
if 'params_ema' in chkpt:
|
377 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
378 |
+
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
379 |
+
elif 'params' in chkpt:
|
380 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
381 |
+
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
382 |
+
else:
|
383 |
+
raise ValueError('Wrong params!')
|
384 |
+
|
385 |
+
|
386 |
+
def forward(self, x):
|
387 |
+
x = self.encoder(x)
|
388 |
+
quant, codebook_loss, quant_stats = self.quantize(x)
|
389 |
+
x = self.generator(quant)
|
390 |
+
return x, codebook_loss, quant_stats
|
391 |
+
|
392 |
+
|
393 |
+
|
394 |
+
# patch based discriminator
|
395 |
+
@ARCH_REGISTRY.register()
|
396 |
+
class VQGANDiscriminator(nn.Module):
|
397 |
+
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
398 |
+
super().__init__()
|
399 |
+
|
400 |
+
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
401 |
+
ndf_mult = 1
|
402 |
+
ndf_mult_prev = 1
|
403 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
404 |
+
ndf_mult_prev = ndf_mult
|
405 |
+
ndf_mult = min(2 ** n, 8)
|
406 |
+
layers += [
|
407 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
408 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
409 |
+
nn.LeakyReLU(0.2, True)
|
410 |
+
]
|
411 |
+
|
412 |
+
ndf_mult_prev = ndf_mult
|
413 |
+
ndf_mult = min(2 ** n_layers, 8)
|
414 |
+
|
415 |
+
layers += [
|
416 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
417 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
418 |
+
nn.LeakyReLU(0.2, True)
|
419 |
+
]
|
420 |
+
|
421 |
+
layers += [
|
422 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
423 |
+
self.main = nn.Sequential(*layers)
|
424 |
+
|
425 |
+
if model_path is not None:
|
426 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
427 |
+
if 'params_d' in chkpt:
|
428 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
429 |
+
elif 'params' in chkpt:
|
430 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
431 |
+
else:
|
432 |
+
raise ValueError('Wrong params!')
|
433 |
+
|
434 |
+
def forward(self, x):
|
435 |
+
return self.main(x)
|
modules/codeformer_model.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import modules.face_restoration
|
7 |
+
import modules.shared
|
8 |
+
from modules import shared, devices, modelloader, errors
|
9 |
+
from modules.paths import models_path
|
10 |
+
|
11 |
+
# codeformer people made a choice to include modified basicsr library to their project which makes
|
12 |
+
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
13 |
+
# I am making a choice to include some files from codeformer to work around this issue.
|
14 |
+
model_dir = "Codeformer"
|
15 |
+
model_path = os.path.join(models_path, model_dir)
|
16 |
+
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
17 |
+
|
18 |
+
codeformer = None
|
19 |
+
|
20 |
+
|
21 |
+
def setup_model(dirname):
|
22 |
+
os.makedirs(model_path, exist_ok=True)
|
23 |
+
|
24 |
+
path = modules.paths.paths.get("CodeFormer", None)
|
25 |
+
if path is None:
|
26 |
+
return
|
27 |
+
|
28 |
+
try:
|
29 |
+
from torchvision.transforms.functional import normalize
|
30 |
+
from modules.codeformer.codeformer_arch import CodeFormer
|
31 |
+
from basicsr.utils import img2tensor, tensor2img
|
32 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
33 |
+
from facelib.detection.retinaface import retinaface
|
34 |
+
|
35 |
+
net_class = CodeFormer
|
36 |
+
|
37 |
+
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
|
38 |
+
def name(self):
|
39 |
+
return "CodeFormer"
|
40 |
+
|
41 |
+
def __init__(self, dirname):
|
42 |
+
self.net = None
|
43 |
+
self.face_helper = None
|
44 |
+
self.cmd_dir = dirname
|
45 |
+
|
46 |
+
def create_models(self):
|
47 |
+
|
48 |
+
if self.net is not None and self.face_helper is not None:
|
49 |
+
self.net.to(devices.device_codeformer)
|
50 |
+
return self.net, self.face_helper
|
51 |
+
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
|
52 |
+
if len(model_paths) != 0:
|
53 |
+
ckpt_path = model_paths[0]
|
54 |
+
else:
|
55 |
+
print("Unable to load codeformer model.")
|
56 |
+
return None, None
|
57 |
+
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
58 |
+
checkpoint = torch.load(ckpt_path)['params_ema']
|
59 |
+
net.load_state_dict(checkpoint)
|
60 |
+
net.eval()
|
61 |
+
|
62 |
+
if hasattr(retinaface, 'device'):
|
63 |
+
retinaface.device = devices.device_codeformer
|
64 |
+
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
|
65 |
+
|
66 |
+
self.net = net
|
67 |
+
self.face_helper = face_helper
|
68 |
+
|
69 |
+
return net, face_helper
|
70 |
+
|
71 |
+
def send_model_to(self, device):
|
72 |
+
self.net.to(device)
|
73 |
+
self.face_helper.face_det.to(device)
|
74 |
+
self.face_helper.face_parse.to(device)
|
75 |
+
|
76 |
+
def restore(self, np_image, w=None):
|
77 |
+
np_image = np_image[:, :, ::-1]
|
78 |
+
|
79 |
+
original_resolution = np_image.shape[0:2]
|
80 |
+
|
81 |
+
self.create_models()
|
82 |
+
if self.net is None or self.face_helper is None:
|
83 |
+
return np_image
|
84 |
+
|
85 |
+
self.send_model_to(devices.device_codeformer)
|
86 |
+
|
87 |
+
self.face_helper.clean_all()
|
88 |
+
self.face_helper.read_image(np_image)
|
89 |
+
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
90 |
+
self.face_helper.align_warp_face()
|
91 |
+
|
92 |
+
for cropped_face in self.face_helper.cropped_faces:
|
93 |
+
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
94 |
+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
95 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
96 |
+
|
97 |
+
try:
|
98 |
+
with torch.no_grad():
|
99 |
+
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
100 |
+
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
101 |
+
del output
|
102 |
+
devices.torch_gc()
|
103 |
+
except Exception:
|
104 |
+
errors.report('Failed inference for CodeFormer', exc_info=True)
|
105 |
+
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
106 |
+
|
107 |
+
restored_face = restored_face.astype('uint8')
|
108 |
+
self.face_helper.add_restored_face(restored_face)
|
109 |
+
|
110 |
+
self.face_helper.get_inverse_affine(None)
|
111 |
+
|
112 |
+
restored_img = self.face_helper.paste_faces_to_input_image()
|
113 |
+
restored_img = restored_img[:, :, ::-1]
|
114 |
+
|
115 |
+
if original_resolution != restored_img.shape[0:2]:
|
116 |
+
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
117 |
+
|
118 |
+
self.face_helper.clean_all()
|
119 |
+
|
120 |
+
if shared.opts.face_restoration_unload:
|
121 |
+
self.send_model_to(devices.cpu)
|
122 |
+
|
123 |
+
return restored_img
|
124 |
+
|
125 |
+
global codeformer
|
126 |
+
codeformer = FaceRestorerCodeFormer(dirname)
|
127 |
+
shared.face_restorers.append(codeformer)
|
128 |
+
|
129 |
+
except Exception:
|
130 |
+
errors.report("Error setting up CodeFormer", exc_info=True)
|
131 |
+
|
132 |
+
# sys.path = stored_sys_path
|
modules/config_states.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Supports saving and restoring webui and extensions from a known working set of commits
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
import tqdm
|
9 |
+
|
10 |
+
from datetime import datetime
|
11 |
+
from collections import OrderedDict
|
12 |
+
import git
|
13 |
+
|
14 |
+
from modules import shared, extensions, errors
|
15 |
+
from modules.paths_internal import script_path, config_states_dir
|
16 |
+
|
17 |
+
|
18 |
+
all_config_states = OrderedDict()
|
19 |
+
|
20 |
+
|
21 |
+
def list_config_states():
|
22 |
+
global all_config_states
|
23 |
+
|
24 |
+
all_config_states.clear()
|
25 |
+
os.makedirs(config_states_dir, exist_ok=True)
|
26 |
+
|
27 |
+
config_states = []
|
28 |
+
for filename in os.listdir(config_states_dir):
|
29 |
+
if filename.endswith(".json"):
|
30 |
+
path = os.path.join(config_states_dir, filename)
|
31 |
+
with open(path, "r", encoding="utf-8") as f:
|
32 |
+
j = json.load(f)
|
33 |
+
j["filepath"] = path
|
34 |
+
config_states.append(j)
|
35 |
+
|
36 |
+
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
37 |
+
|
38 |
+
for cs in config_states:
|
39 |
+
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
40 |
+
name = cs.get("name", "Config")
|
41 |
+
full_name = f"{name}: {timestamp}"
|
42 |
+
all_config_states[full_name] = cs
|
43 |
+
|
44 |
+
return all_config_states
|
45 |
+
|
46 |
+
|
47 |
+
def get_webui_config():
|
48 |
+
webui_repo = None
|
49 |
+
|
50 |
+
try:
|
51 |
+
if os.path.exists(os.path.join(script_path, ".git")):
|
52 |
+
webui_repo = git.Repo(script_path)
|
53 |
+
except Exception:
|
54 |
+
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
55 |
+
|
56 |
+
webui_remote = None
|
57 |
+
webui_commit_hash = None
|
58 |
+
webui_commit_date = None
|
59 |
+
webui_branch = None
|
60 |
+
if webui_repo and not webui_repo.bare:
|
61 |
+
try:
|
62 |
+
webui_remote = next(webui_repo.remote().urls, None)
|
63 |
+
head = webui_repo.head.commit
|
64 |
+
webui_commit_date = webui_repo.head.commit.committed_date
|
65 |
+
webui_commit_hash = head.hexsha
|
66 |
+
webui_branch = webui_repo.active_branch.name
|
67 |
+
|
68 |
+
except Exception:
|
69 |
+
webui_remote = None
|
70 |
+
|
71 |
+
return {
|
72 |
+
"remote": webui_remote,
|
73 |
+
"commit_hash": webui_commit_hash,
|
74 |
+
"commit_date": webui_commit_date,
|
75 |
+
"branch": webui_branch,
|
76 |
+
}
|
77 |
+
|
78 |
+
|
79 |
+
def get_extension_config():
|
80 |
+
ext_config = {}
|
81 |
+
|
82 |
+
for ext in extensions.extensions:
|
83 |
+
ext.read_info_from_repo()
|
84 |
+
|
85 |
+
entry = {
|
86 |
+
"name": ext.name,
|
87 |
+
"path": ext.path,
|
88 |
+
"enabled": ext.enabled,
|
89 |
+
"is_builtin": ext.is_builtin,
|
90 |
+
"remote": ext.remote,
|
91 |
+
"commit_hash": ext.commit_hash,
|
92 |
+
"commit_date": ext.commit_date,
|
93 |
+
"branch": ext.branch,
|
94 |
+
"have_info_from_repo": ext.have_info_from_repo
|
95 |
+
}
|
96 |
+
|
97 |
+
ext_config[ext.name] = entry
|
98 |
+
|
99 |
+
return ext_config
|
100 |
+
|
101 |
+
|
102 |
+
def get_config():
|
103 |
+
creation_time = datetime.now().timestamp()
|
104 |
+
webui_config = get_webui_config()
|
105 |
+
ext_config = get_extension_config()
|
106 |
+
|
107 |
+
return {
|
108 |
+
"created_at": creation_time,
|
109 |
+
"webui": webui_config,
|
110 |
+
"extensions": ext_config
|
111 |
+
}
|
112 |
+
|
113 |
+
|
114 |
+
def restore_webui_config(config):
|
115 |
+
print("* Restoring webui state...")
|
116 |
+
|
117 |
+
if "webui" not in config:
|
118 |
+
print("Error: No webui data saved to config")
|
119 |
+
return
|
120 |
+
|
121 |
+
webui_config = config["webui"]
|
122 |
+
|
123 |
+
if "commit_hash" not in webui_config:
|
124 |
+
print("Error: No commit saved to webui config")
|
125 |
+
return
|
126 |
+
|
127 |
+
webui_commit_hash = webui_config.get("commit_hash", None)
|
128 |
+
webui_repo = None
|
129 |
+
|
130 |
+
try:
|
131 |
+
if os.path.exists(os.path.join(script_path, ".git")):
|
132 |
+
webui_repo = git.Repo(script_path)
|
133 |
+
except Exception:
|
134 |
+
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
135 |
+
return
|
136 |
+
|
137 |
+
try:
|
138 |
+
webui_repo.git.fetch(all=True)
|
139 |
+
webui_repo.git.reset(webui_commit_hash, hard=True)
|
140 |
+
print(f"* Restored webui to commit {webui_commit_hash}.")
|
141 |
+
except Exception:
|
142 |
+
errors.report(f"Error restoring webui to commit{webui_commit_hash}")
|
143 |
+
|
144 |
+
|
145 |
+
def restore_extension_config(config):
|
146 |
+
print("* Restoring extension state...")
|
147 |
+
|
148 |
+
if "extensions" not in config:
|
149 |
+
print("Error: No extension data saved to config")
|
150 |
+
return
|
151 |
+
|
152 |
+
ext_config = config["extensions"]
|
153 |
+
|
154 |
+
results = []
|
155 |
+
disabled = []
|
156 |
+
|
157 |
+
for ext in tqdm.tqdm(extensions.extensions):
|
158 |
+
if ext.is_builtin:
|
159 |
+
continue
|
160 |
+
|
161 |
+
ext.read_info_from_repo()
|
162 |
+
current_commit = ext.commit_hash
|
163 |
+
|
164 |
+
if ext.name not in ext_config:
|
165 |
+
ext.disabled = True
|
166 |
+
disabled.append(ext.name)
|
167 |
+
results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
|
168 |
+
continue
|
169 |
+
|
170 |
+
entry = ext_config[ext.name]
|
171 |
+
|
172 |
+
if "commit_hash" in entry and entry["commit_hash"]:
|
173 |
+
try:
|
174 |
+
ext.fetch_and_reset_hard(entry["commit_hash"])
|
175 |
+
ext.read_info_from_repo()
|
176 |
+
if current_commit != entry["commit_hash"]:
|
177 |
+
results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
|
178 |
+
except Exception as ex:
|
179 |
+
results.append((ext, current_commit[:8], False, ex))
|
180 |
+
else:
|
181 |
+
results.append((ext, current_commit[:8], False, "No commit hash found in config"))
|
182 |
+
|
183 |
+
if not entry.get("enabled", False):
|
184 |
+
ext.disabled = True
|
185 |
+
disabled.append(ext.name)
|
186 |
+
else:
|
187 |
+
ext.disabled = False
|
188 |
+
|
189 |
+
shared.opts.disabled_extensions = disabled
|
190 |
+
shared.opts.save(shared.config_filename)
|
191 |
+
|
192 |
+
print("* Finished restoring extensions. Results:")
|
193 |
+
for ext, prev_commit, success, result in results:
|
194 |
+
if success:
|
195 |
+
print(f" + {ext.name}: {prev_commit} -> {result}")
|
196 |
+
else:
|
197 |
+
print(f" ! {ext.name}: FAILURE ({result})")
|
modules/deepbooru.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
8 |
+
|
9 |
+
re_special = re.compile(r'([\\()])')
|
10 |
+
|
11 |
+
|
12 |
+
class DeepDanbooru:
|
13 |
+
def __init__(self):
|
14 |
+
self.model = None
|
15 |
+
|
16 |
+
def load(self):
|
17 |
+
if self.model is not None:
|
18 |
+
return
|
19 |
+
|
20 |
+
files = modelloader.load_models(
|
21 |
+
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
|
22 |
+
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
23 |
+
ext_filter=[".pt"],
|
24 |
+
download_name='model-resnet_custom_v3.pt',
|
25 |
+
)
|
26 |
+
|
27 |
+
self.model = deepbooru_model.DeepDanbooruModel()
|
28 |
+
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
29 |
+
|
30 |
+
self.model.eval()
|
31 |
+
self.model.to(devices.cpu, devices.dtype)
|
32 |
+
|
33 |
+
def start(self):
|
34 |
+
self.load()
|
35 |
+
self.model.to(devices.device)
|
36 |
+
|
37 |
+
def stop(self):
|
38 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
39 |
+
self.model.to(devices.cpu)
|
40 |
+
devices.torch_gc()
|
41 |
+
|
42 |
+
def tag(self, pil_image):
|
43 |
+
self.start()
|
44 |
+
res = self.tag_multi(pil_image)
|
45 |
+
self.stop()
|
46 |
+
|
47 |
+
return res
|
48 |
+
|
49 |
+
def tag_multi(self, pil_image, force_disable_ranks=False):
|
50 |
+
threshold = shared.opts.interrogate_deepbooru_score_threshold
|
51 |
+
use_spaces = shared.opts.deepbooru_use_spaces
|
52 |
+
use_escape = shared.opts.deepbooru_escape
|
53 |
+
alpha_sort = shared.opts.deepbooru_sort_alpha
|
54 |
+
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
|
55 |
+
|
56 |
+
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
57 |
+
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
58 |
+
|
59 |
+
with torch.no_grad(), devices.autocast():
|
60 |
+
x = torch.from_numpy(a).to(devices.device)
|
61 |
+
y = self.model(x)[0].detach().cpu().numpy()
|
62 |
+
|
63 |
+
probability_dict = {}
|
64 |
+
|
65 |
+
for tag, probability in zip(self.model.tags, y):
|
66 |
+
if probability < threshold:
|
67 |
+
continue
|
68 |
+
|
69 |
+
if tag.startswith("rating:"):
|
70 |
+
continue
|
71 |
+
|
72 |
+
probability_dict[tag] = probability
|
73 |
+
|
74 |
+
if alpha_sort:
|
75 |
+
tags = sorted(probability_dict)
|
76 |
+
else:
|
77 |
+
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
78 |
+
|
79 |
+
res = []
|
80 |
+
|
81 |
+
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
82 |
+
|
83 |
+
for tag in [x for x in tags if x not in filtertags]:
|
84 |
+
probability = probability_dict[tag]
|
85 |
+
tag_outformat = tag
|
86 |
+
if use_spaces:
|
87 |
+
tag_outformat = tag_outformat.replace('_', ' ')
|
88 |
+
if use_escape:
|
89 |
+
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
90 |
+
if include_ranks:
|
91 |
+
tag_outformat = f"({tag_outformat}:{probability:.3f})"
|
92 |
+
|
93 |
+
res.append(tag_outformat)
|
94 |
+
|
95 |
+
return ", ".join(res)
|
96 |
+
|
97 |
+
|
98 |
+
model = DeepDanbooru()
|
modules/deepbooru_model.py
ADDED
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from modules import devices
|
6 |
+
|
7 |
+
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
8 |
+
|
9 |
+
|
10 |
+
class DeepDanbooruModel(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super(DeepDanbooruModel, self).__init__()
|
13 |
+
|
14 |
+
self.tags = []
|
15 |
+
|
16 |
+
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
17 |
+
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
18 |
+
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
19 |
+
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
20 |
+
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
21 |
+
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
22 |
+
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
23 |
+
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
24 |
+
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
25 |
+
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
26 |
+
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
27 |
+
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
28 |
+
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
29 |
+
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
30 |
+
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
31 |
+
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
32 |
+
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
33 |
+
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
34 |
+
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
35 |
+
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
36 |
+
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
37 |
+
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
38 |
+
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
39 |
+
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
40 |
+
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
41 |
+
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
42 |
+
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
43 |
+
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
44 |
+
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
45 |
+
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
46 |
+
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
47 |
+
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
48 |
+
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
49 |
+
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
50 |
+
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
51 |
+
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
52 |
+
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
53 |
+
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
54 |
+
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
55 |
+
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
56 |
+
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
57 |
+
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
58 |
+
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
59 |
+
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
60 |
+
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
61 |
+
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
62 |
+
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
63 |
+
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
64 |
+
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
65 |
+
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
66 |
+
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
67 |
+
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
68 |
+
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
69 |
+
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
70 |
+
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
71 |
+
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
72 |
+
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
73 |
+
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
74 |
+
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
75 |
+
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
76 |
+
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
77 |
+
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
78 |
+
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
79 |
+
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
80 |
+
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
81 |
+
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
82 |
+
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
83 |
+
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
84 |
+
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
85 |
+
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
86 |
+
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
87 |
+
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
88 |
+
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
89 |
+
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
90 |
+
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
91 |
+
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
92 |
+
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
93 |
+
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
94 |
+
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
95 |
+
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
96 |
+
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
97 |
+
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
98 |
+
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
99 |
+
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
100 |
+
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
101 |
+
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
102 |
+
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
103 |
+
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
104 |
+
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
105 |
+
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
106 |
+
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
107 |
+
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
108 |
+
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
109 |
+
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
110 |
+
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
111 |
+
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
112 |
+
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
113 |
+
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
114 |
+
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
115 |
+
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
116 |
+
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
117 |
+
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
118 |
+
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
119 |
+
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
120 |
+
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
121 |
+
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
122 |
+
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
123 |
+
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
124 |
+
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
125 |
+
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
126 |
+
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
127 |
+
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
128 |
+
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
129 |
+
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
130 |
+
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
131 |
+
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
132 |
+
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
133 |
+
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
134 |
+
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
135 |
+
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
136 |
+
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
137 |
+
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
138 |
+
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
139 |
+
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
140 |
+
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
141 |
+
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
142 |
+
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
143 |
+
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
144 |
+
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
145 |
+
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
146 |
+
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
147 |
+
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
148 |
+
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
149 |
+
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
150 |
+
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
151 |
+
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
152 |
+
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
153 |
+
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
154 |
+
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
155 |
+
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
156 |
+
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
157 |
+
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
158 |
+
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
159 |
+
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
160 |
+
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
161 |
+
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
162 |
+
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
163 |
+
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
164 |
+
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
165 |
+
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
166 |
+
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
167 |
+
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
168 |
+
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
169 |
+
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
170 |
+
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
171 |
+
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
172 |
+
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
173 |
+
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
174 |
+
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
175 |
+
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
176 |
+
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
177 |
+
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
178 |
+
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
179 |
+
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
180 |
+
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
181 |
+
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
182 |
+
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
183 |
+
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
184 |
+
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
185 |
+
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
186 |
+
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
187 |
+
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
188 |
+
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
189 |
+
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
190 |
+
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
191 |
+
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
192 |
+
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
193 |
+
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
194 |
+
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
195 |
+
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
196 |
+
|
197 |
+
def forward(self, *inputs):
|
198 |
+
t_358, = inputs
|
199 |
+
t_359 = t_358.permute(*[0, 3, 1, 2])
|
200 |
+
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
201 |
+
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
202 |
+
t_361 = F.relu(t_360)
|
203 |
+
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
204 |
+
t_362 = self.n_MaxPool_0(t_361)
|
205 |
+
t_363 = self.n_Conv_1(t_362)
|
206 |
+
t_364 = self.n_Conv_2(t_362)
|
207 |
+
t_365 = F.relu(t_364)
|
208 |
+
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
209 |
+
t_366 = self.n_Conv_3(t_365_padded)
|
210 |
+
t_367 = F.relu(t_366)
|
211 |
+
t_368 = self.n_Conv_4(t_367)
|
212 |
+
t_369 = torch.add(t_368, t_363)
|
213 |
+
t_370 = F.relu(t_369)
|
214 |
+
t_371 = self.n_Conv_5(t_370)
|
215 |
+
t_372 = F.relu(t_371)
|
216 |
+
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
217 |
+
t_373 = self.n_Conv_6(t_372_padded)
|
218 |
+
t_374 = F.relu(t_373)
|
219 |
+
t_375 = self.n_Conv_7(t_374)
|
220 |
+
t_376 = torch.add(t_375, t_370)
|
221 |
+
t_377 = F.relu(t_376)
|
222 |
+
t_378 = self.n_Conv_8(t_377)
|
223 |
+
t_379 = F.relu(t_378)
|
224 |
+
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
225 |
+
t_380 = self.n_Conv_9(t_379_padded)
|
226 |
+
t_381 = F.relu(t_380)
|
227 |
+
t_382 = self.n_Conv_10(t_381)
|
228 |
+
t_383 = torch.add(t_382, t_377)
|
229 |
+
t_384 = F.relu(t_383)
|
230 |
+
t_385 = self.n_Conv_11(t_384)
|
231 |
+
t_386 = self.n_Conv_12(t_384)
|
232 |
+
t_387 = F.relu(t_386)
|
233 |
+
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
234 |
+
t_388 = self.n_Conv_13(t_387_padded)
|
235 |
+
t_389 = F.relu(t_388)
|
236 |
+
t_390 = self.n_Conv_14(t_389)
|
237 |
+
t_391 = torch.add(t_390, t_385)
|
238 |
+
t_392 = F.relu(t_391)
|
239 |
+
t_393 = self.n_Conv_15(t_392)
|
240 |
+
t_394 = F.relu(t_393)
|
241 |
+
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
242 |
+
t_395 = self.n_Conv_16(t_394_padded)
|
243 |
+
t_396 = F.relu(t_395)
|
244 |
+
t_397 = self.n_Conv_17(t_396)
|
245 |
+
t_398 = torch.add(t_397, t_392)
|
246 |
+
t_399 = F.relu(t_398)
|
247 |
+
t_400 = self.n_Conv_18(t_399)
|
248 |
+
t_401 = F.relu(t_400)
|
249 |
+
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
250 |
+
t_402 = self.n_Conv_19(t_401_padded)
|
251 |
+
t_403 = F.relu(t_402)
|
252 |
+
t_404 = self.n_Conv_20(t_403)
|
253 |
+
t_405 = torch.add(t_404, t_399)
|
254 |
+
t_406 = F.relu(t_405)
|
255 |
+
t_407 = self.n_Conv_21(t_406)
|
256 |
+
t_408 = F.relu(t_407)
|
257 |
+
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
258 |
+
t_409 = self.n_Conv_22(t_408_padded)
|
259 |
+
t_410 = F.relu(t_409)
|
260 |
+
t_411 = self.n_Conv_23(t_410)
|
261 |
+
t_412 = torch.add(t_411, t_406)
|
262 |
+
t_413 = F.relu(t_412)
|
263 |
+
t_414 = self.n_Conv_24(t_413)
|
264 |
+
t_415 = F.relu(t_414)
|
265 |
+
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
266 |
+
t_416 = self.n_Conv_25(t_415_padded)
|
267 |
+
t_417 = F.relu(t_416)
|
268 |
+
t_418 = self.n_Conv_26(t_417)
|
269 |
+
t_419 = torch.add(t_418, t_413)
|
270 |
+
t_420 = F.relu(t_419)
|
271 |
+
t_421 = self.n_Conv_27(t_420)
|
272 |
+
t_422 = F.relu(t_421)
|
273 |
+
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
274 |
+
t_423 = self.n_Conv_28(t_422_padded)
|
275 |
+
t_424 = F.relu(t_423)
|
276 |
+
t_425 = self.n_Conv_29(t_424)
|
277 |
+
t_426 = torch.add(t_425, t_420)
|
278 |
+
t_427 = F.relu(t_426)
|
279 |
+
t_428 = self.n_Conv_30(t_427)
|
280 |
+
t_429 = F.relu(t_428)
|
281 |
+
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
282 |
+
t_430 = self.n_Conv_31(t_429_padded)
|
283 |
+
t_431 = F.relu(t_430)
|
284 |
+
t_432 = self.n_Conv_32(t_431)
|
285 |
+
t_433 = torch.add(t_432, t_427)
|
286 |
+
t_434 = F.relu(t_433)
|
287 |
+
t_435 = self.n_Conv_33(t_434)
|
288 |
+
t_436 = F.relu(t_435)
|
289 |
+
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
290 |
+
t_437 = self.n_Conv_34(t_436_padded)
|
291 |
+
t_438 = F.relu(t_437)
|
292 |
+
t_439 = self.n_Conv_35(t_438)
|
293 |
+
t_440 = torch.add(t_439, t_434)
|
294 |
+
t_441 = F.relu(t_440)
|
295 |
+
t_442 = self.n_Conv_36(t_441)
|
296 |
+
t_443 = self.n_Conv_37(t_441)
|
297 |
+
t_444 = F.relu(t_443)
|
298 |
+
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
299 |
+
t_445 = self.n_Conv_38(t_444_padded)
|
300 |
+
t_446 = F.relu(t_445)
|
301 |
+
t_447 = self.n_Conv_39(t_446)
|
302 |
+
t_448 = torch.add(t_447, t_442)
|
303 |
+
t_449 = F.relu(t_448)
|
304 |
+
t_450 = self.n_Conv_40(t_449)
|
305 |
+
t_451 = F.relu(t_450)
|
306 |
+
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
307 |
+
t_452 = self.n_Conv_41(t_451_padded)
|
308 |
+
t_453 = F.relu(t_452)
|
309 |
+
t_454 = self.n_Conv_42(t_453)
|
310 |
+
t_455 = torch.add(t_454, t_449)
|
311 |
+
t_456 = F.relu(t_455)
|
312 |
+
t_457 = self.n_Conv_43(t_456)
|
313 |
+
t_458 = F.relu(t_457)
|
314 |
+
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
315 |
+
t_459 = self.n_Conv_44(t_458_padded)
|
316 |
+
t_460 = F.relu(t_459)
|
317 |
+
t_461 = self.n_Conv_45(t_460)
|
318 |
+
t_462 = torch.add(t_461, t_456)
|
319 |
+
t_463 = F.relu(t_462)
|
320 |
+
t_464 = self.n_Conv_46(t_463)
|
321 |
+
t_465 = F.relu(t_464)
|
322 |
+
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
323 |
+
t_466 = self.n_Conv_47(t_465_padded)
|
324 |
+
t_467 = F.relu(t_466)
|
325 |
+
t_468 = self.n_Conv_48(t_467)
|
326 |
+
t_469 = torch.add(t_468, t_463)
|
327 |
+
t_470 = F.relu(t_469)
|
328 |
+
t_471 = self.n_Conv_49(t_470)
|
329 |
+
t_472 = F.relu(t_471)
|
330 |
+
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
331 |
+
t_473 = self.n_Conv_50(t_472_padded)
|
332 |
+
t_474 = F.relu(t_473)
|
333 |
+
t_475 = self.n_Conv_51(t_474)
|
334 |
+
t_476 = torch.add(t_475, t_470)
|
335 |
+
t_477 = F.relu(t_476)
|
336 |
+
t_478 = self.n_Conv_52(t_477)
|
337 |
+
t_479 = F.relu(t_478)
|
338 |
+
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
339 |
+
t_480 = self.n_Conv_53(t_479_padded)
|
340 |
+
t_481 = F.relu(t_480)
|
341 |
+
t_482 = self.n_Conv_54(t_481)
|
342 |
+
t_483 = torch.add(t_482, t_477)
|
343 |
+
t_484 = F.relu(t_483)
|
344 |
+
t_485 = self.n_Conv_55(t_484)
|
345 |
+
t_486 = F.relu(t_485)
|
346 |
+
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
347 |
+
t_487 = self.n_Conv_56(t_486_padded)
|
348 |
+
t_488 = F.relu(t_487)
|
349 |
+
t_489 = self.n_Conv_57(t_488)
|
350 |
+
t_490 = torch.add(t_489, t_484)
|
351 |
+
t_491 = F.relu(t_490)
|
352 |
+
t_492 = self.n_Conv_58(t_491)
|
353 |
+
t_493 = F.relu(t_492)
|
354 |
+
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
355 |
+
t_494 = self.n_Conv_59(t_493_padded)
|
356 |
+
t_495 = F.relu(t_494)
|
357 |
+
t_496 = self.n_Conv_60(t_495)
|
358 |
+
t_497 = torch.add(t_496, t_491)
|
359 |
+
t_498 = F.relu(t_497)
|
360 |
+
t_499 = self.n_Conv_61(t_498)
|
361 |
+
t_500 = F.relu(t_499)
|
362 |
+
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
363 |
+
t_501 = self.n_Conv_62(t_500_padded)
|
364 |
+
t_502 = F.relu(t_501)
|
365 |
+
t_503 = self.n_Conv_63(t_502)
|
366 |
+
t_504 = torch.add(t_503, t_498)
|
367 |
+
t_505 = F.relu(t_504)
|
368 |
+
t_506 = self.n_Conv_64(t_505)
|
369 |
+
t_507 = F.relu(t_506)
|
370 |
+
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
371 |
+
t_508 = self.n_Conv_65(t_507_padded)
|
372 |
+
t_509 = F.relu(t_508)
|
373 |
+
t_510 = self.n_Conv_66(t_509)
|
374 |
+
t_511 = torch.add(t_510, t_505)
|
375 |
+
t_512 = F.relu(t_511)
|
376 |
+
t_513 = self.n_Conv_67(t_512)
|
377 |
+
t_514 = F.relu(t_513)
|
378 |
+
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
379 |
+
t_515 = self.n_Conv_68(t_514_padded)
|
380 |
+
t_516 = F.relu(t_515)
|
381 |
+
t_517 = self.n_Conv_69(t_516)
|
382 |
+
t_518 = torch.add(t_517, t_512)
|
383 |
+
t_519 = F.relu(t_518)
|
384 |
+
t_520 = self.n_Conv_70(t_519)
|
385 |
+
t_521 = F.relu(t_520)
|
386 |
+
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
387 |
+
t_522 = self.n_Conv_71(t_521_padded)
|
388 |
+
t_523 = F.relu(t_522)
|
389 |
+
t_524 = self.n_Conv_72(t_523)
|
390 |
+
t_525 = torch.add(t_524, t_519)
|
391 |
+
t_526 = F.relu(t_525)
|
392 |
+
t_527 = self.n_Conv_73(t_526)
|
393 |
+
t_528 = F.relu(t_527)
|
394 |
+
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
395 |
+
t_529 = self.n_Conv_74(t_528_padded)
|
396 |
+
t_530 = F.relu(t_529)
|
397 |
+
t_531 = self.n_Conv_75(t_530)
|
398 |
+
t_532 = torch.add(t_531, t_526)
|
399 |
+
t_533 = F.relu(t_532)
|
400 |
+
t_534 = self.n_Conv_76(t_533)
|
401 |
+
t_535 = F.relu(t_534)
|
402 |
+
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
403 |
+
t_536 = self.n_Conv_77(t_535_padded)
|
404 |
+
t_537 = F.relu(t_536)
|
405 |
+
t_538 = self.n_Conv_78(t_537)
|
406 |
+
t_539 = torch.add(t_538, t_533)
|
407 |
+
t_540 = F.relu(t_539)
|
408 |
+
t_541 = self.n_Conv_79(t_540)
|
409 |
+
t_542 = F.relu(t_541)
|
410 |
+
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
411 |
+
t_543 = self.n_Conv_80(t_542_padded)
|
412 |
+
t_544 = F.relu(t_543)
|
413 |
+
t_545 = self.n_Conv_81(t_544)
|
414 |
+
t_546 = torch.add(t_545, t_540)
|
415 |
+
t_547 = F.relu(t_546)
|
416 |
+
t_548 = self.n_Conv_82(t_547)
|
417 |
+
t_549 = F.relu(t_548)
|
418 |
+
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
419 |
+
t_550 = self.n_Conv_83(t_549_padded)
|
420 |
+
t_551 = F.relu(t_550)
|
421 |
+
t_552 = self.n_Conv_84(t_551)
|
422 |
+
t_553 = torch.add(t_552, t_547)
|
423 |
+
t_554 = F.relu(t_553)
|
424 |
+
t_555 = self.n_Conv_85(t_554)
|
425 |
+
t_556 = F.relu(t_555)
|
426 |
+
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
427 |
+
t_557 = self.n_Conv_86(t_556_padded)
|
428 |
+
t_558 = F.relu(t_557)
|
429 |
+
t_559 = self.n_Conv_87(t_558)
|
430 |
+
t_560 = torch.add(t_559, t_554)
|
431 |
+
t_561 = F.relu(t_560)
|
432 |
+
t_562 = self.n_Conv_88(t_561)
|
433 |
+
t_563 = F.relu(t_562)
|
434 |
+
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
435 |
+
t_564 = self.n_Conv_89(t_563_padded)
|
436 |
+
t_565 = F.relu(t_564)
|
437 |
+
t_566 = self.n_Conv_90(t_565)
|
438 |
+
t_567 = torch.add(t_566, t_561)
|
439 |
+
t_568 = F.relu(t_567)
|
440 |
+
t_569 = self.n_Conv_91(t_568)
|
441 |
+
t_570 = F.relu(t_569)
|
442 |
+
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
443 |
+
t_571 = self.n_Conv_92(t_570_padded)
|
444 |
+
t_572 = F.relu(t_571)
|
445 |
+
t_573 = self.n_Conv_93(t_572)
|
446 |
+
t_574 = torch.add(t_573, t_568)
|
447 |
+
t_575 = F.relu(t_574)
|
448 |
+
t_576 = self.n_Conv_94(t_575)
|
449 |
+
t_577 = F.relu(t_576)
|
450 |
+
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
451 |
+
t_578 = self.n_Conv_95(t_577_padded)
|
452 |
+
t_579 = F.relu(t_578)
|
453 |
+
t_580 = self.n_Conv_96(t_579)
|
454 |
+
t_581 = torch.add(t_580, t_575)
|
455 |
+
t_582 = F.relu(t_581)
|
456 |
+
t_583 = self.n_Conv_97(t_582)
|
457 |
+
t_584 = F.relu(t_583)
|
458 |
+
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
459 |
+
t_585 = self.n_Conv_98(t_584_padded)
|
460 |
+
t_586 = F.relu(t_585)
|
461 |
+
t_587 = self.n_Conv_99(t_586)
|
462 |
+
t_588 = self.n_Conv_100(t_582)
|
463 |
+
t_589 = torch.add(t_587, t_588)
|
464 |
+
t_590 = F.relu(t_589)
|
465 |
+
t_591 = self.n_Conv_101(t_590)
|
466 |
+
t_592 = F.relu(t_591)
|
467 |
+
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
468 |
+
t_593 = self.n_Conv_102(t_592_padded)
|
469 |
+
t_594 = F.relu(t_593)
|
470 |
+
t_595 = self.n_Conv_103(t_594)
|
471 |
+
t_596 = torch.add(t_595, t_590)
|
472 |
+
t_597 = F.relu(t_596)
|
473 |
+
t_598 = self.n_Conv_104(t_597)
|
474 |
+
t_599 = F.relu(t_598)
|
475 |
+
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
476 |
+
t_600 = self.n_Conv_105(t_599_padded)
|
477 |
+
t_601 = F.relu(t_600)
|
478 |
+
t_602 = self.n_Conv_106(t_601)
|
479 |
+
t_603 = torch.add(t_602, t_597)
|
480 |
+
t_604 = F.relu(t_603)
|
481 |
+
t_605 = self.n_Conv_107(t_604)
|
482 |
+
t_606 = F.relu(t_605)
|
483 |
+
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
484 |
+
t_607 = self.n_Conv_108(t_606_padded)
|
485 |
+
t_608 = F.relu(t_607)
|
486 |
+
t_609 = self.n_Conv_109(t_608)
|
487 |
+
t_610 = torch.add(t_609, t_604)
|
488 |
+
t_611 = F.relu(t_610)
|
489 |
+
t_612 = self.n_Conv_110(t_611)
|
490 |
+
t_613 = F.relu(t_612)
|
491 |
+
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
492 |
+
t_614 = self.n_Conv_111(t_613_padded)
|
493 |
+
t_615 = F.relu(t_614)
|
494 |
+
t_616 = self.n_Conv_112(t_615)
|
495 |
+
t_617 = torch.add(t_616, t_611)
|
496 |
+
t_618 = F.relu(t_617)
|
497 |
+
t_619 = self.n_Conv_113(t_618)
|
498 |
+
t_620 = F.relu(t_619)
|
499 |
+
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
500 |
+
t_621 = self.n_Conv_114(t_620_padded)
|
501 |
+
t_622 = F.relu(t_621)
|
502 |
+
t_623 = self.n_Conv_115(t_622)
|
503 |
+
t_624 = torch.add(t_623, t_618)
|
504 |
+
t_625 = F.relu(t_624)
|
505 |
+
t_626 = self.n_Conv_116(t_625)
|
506 |
+
t_627 = F.relu(t_626)
|
507 |
+
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
508 |
+
t_628 = self.n_Conv_117(t_627_padded)
|
509 |
+
t_629 = F.relu(t_628)
|
510 |
+
t_630 = self.n_Conv_118(t_629)
|
511 |
+
t_631 = torch.add(t_630, t_625)
|
512 |
+
t_632 = F.relu(t_631)
|
513 |
+
t_633 = self.n_Conv_119(t_632)
|
514 |
+
t_634 = F.relu(t_633)
|
515 |
+
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
516 |
+
t_635 = self.n_Conv_120(t_634_padded)
|
517 |
+
t_636 = F.relu(t_635)
|
518 |
+
t_637 = self.n_Conv_121(t_636)
|
519 |
+
t_638 = torch.add(t_637, t_632)
|
520 |
+
t_639 = F.relu(t_638)
|
521 |
+
t_640 = self.n_Conv_122(t_639)
|
522 |
+
t_641 = F.relu(t_640)
|
523 |
+
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
524 |
+
t_642 = self.n_Conv_123(t_641_padded)
|
525 |
+
t_643 = F.relu(t_642)
|
526 |
+
t_644 = self.n_Conv_124(t_643)
|
527 |
+
t_645 = torch.add(t_644, t_639)
|
528 |
+
t_646 = F.relu(t_645)
|
529 |
+
t_647 = self.n_Conv_125(t_646)
|
530 |
+
t_648 = F.relu(t_647)
|
531 |
+
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
532 |
+
t_649 = self.n_Conv_126(t_648_padded)
|
533 |
+
t_650 = F.relu(t_649)
|
534 |
+
t_651 = self.n_Conv_127(t_650)
|
535 |
+
t_652 = torch.add(t_651, t_646)
|
536 |
+
t_653 = F.relu(t_652)
|
537 |
+
t_654 = self.n_Conv_128(t_653)
|
538 |
+
t_655 = F.relu(t_654)
|
539 |
+
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
540 |
+
t_656 = self.n_Conv_129(t_655_padded)
|
541 |
+
t_657 = F.relu(t_656)
|
542 |
+
t_658 = self.n_Conv_130(t_657)
|
543 |
+
t_659 = torch.add(t_658, t_653)
|
544 |
+
t_660 = F.relu(t_659)
|
545 |
+
t_661 = self.n_Conv_131(t_660)
|
546 |
+
t_662 = F.relu(t_661)
|
547 |
+
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
548 |
+
t_663 = self.n_Conv_132(t_662_padded)
|
549 |
+
t_664 = F.relu(t_663)
|
550 |
+
t_665 = self.n_Conv_133(t_664)
|
551 |
+
t_666 = torch.add(t_665, t_660)
|
552 |
+
t_667 = F.relu(t_666)
|
553 |
+
t_668 = self.n_Conv_134(t_667)
|
554 |
+
t_669 = F.relu(t_668)
|
555 |
+
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
556 |
+
t_670 = self.n_Conv_135(t_669_padded)
|
557 |
+
t_671 = F.relu(t_670)
|
558 |
+
t_672 = self.n_Conv_136(t_671)
|
559 |
+
t_673 = torch.add(t_672, t_667)
|
560 |
+
t_674 = F.relu(t_673)
|
561 |
+
t_675 = self.n_Conv_137(t_674)
|
562 |
+
t_676 = F.relu(t_675)
|
563 |
+
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
564 |
+
t_677 = self.n_Conv_138(t_676_padded)
|
565 |
+
t_678 = F.relu(t_677)
|
566 |
+
t_679 = self.n_Conv_139(t_678)
|
567 |
+
t_680 = torch.add(t_679, t_674)
|
568 |
+
t_681 = F.relu(t_680)
|
569 |
+
t_682 = self.n_Conv_140(t_681)
|
570 |
+
t_683 = F.relu(t_682)
|
571 |
+
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
572 |
+
t_684 = self.n_Conv_141(t_683_padded)
|
573 |
+
t_685 = F.relu(t_684)
|
574 |
+
t_686 = self.n_Conv_142(t_685)
|
575 |
+
t_687 = torch.add(t_686, t_681)
|
576 |
+
t_688 = F.relu(t_687)
|
577 |
+
t_689 = self.n_Conv_143(t_688)
|
578 |
+
t_690 = F.relu(t_689)
|
579 |
+
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
580 |
+
t_691 = self.n_Conv_144(t_690_padded)
|
581 |
+
t_692 = F.relu(t_691)
|
582 |
+
t_693 = self.n_Conv_145(t_692)
|
583 |
+
t_694 = torch.add(t_693, t_688)
|
584 |
+
t_695 = F.relu(t_694)
|
585 |
+
t_696 = self.n_Conv_146(t_695)
|
586 |
+
t_697 = F.relu(t_696)
|
587 |
+
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
588 |
+
t_698 = self.n_Conv_147(t_697_padded)
|
589 |
+
t_699 = F.relu(t_698)
|
590 |
+
t_700 = self.n_Conv_148(t_699)
|
591 |
+
t_701 = torch.add(t_700, t_695)
|
592 |
+
t_702 = F.relu(t_701)
|
593 |
+
t_703 = self.n_Conv_149(t_702)
|
594 |
+
t_704 = F.relu(t_703)
|
595 |
+
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
596 |
+
t_705 = self.n_Conv_150(t_704_padded)
|
597 |
+
t_706 = F.relu(t_705)
|
598 |
+
t_707 = self.n_Conv_151(t_706)
|
599 |
+
t_708 = torch.add(t_707, t_702)
|
600 |
+
t_709 = F.relu(t_708)
|
601 |
+
t_710 = self.n_Conv_152(t_709)
|
602 |
+
t_711 = F.relu(t_710)
|
603 |
+
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
604 |
+
t_712 = self.n_Conv_153(t_711_padded)
|
605 |
+
t_713 = F.relu(t_712)
|
606 |
+
t_714 = self.n_Conv_154(t_713)
|
607 |
+
t_715 = torch.add(t_714, t_709)
|
608 |
+
t_716 = F.relu(t_715)
|
609 |
+
t_717 = self.n_Conv_155(t_716)
|
610 |
+
t_718 = F.relu(t_717)
|
611 |
+
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
612 |
+
t_719 = self.n_Conv_156(t_718_padded)
|
613 |
+
t_720 = F.relu(t_719)
|
614 |
+
t_721 = self.n_Conv_157(t_720)
|
615 |
+
t_722 = torch.add(t_721, t_716)
|
616 |
+
t_723 = F.relu(t_722)
|
617 |
+
t_724 = self.n_Conv_158(t_723)
|
618 |
+
t_725 = self.n_Conv_159(t_723)
|
619 |
+
t_726 = F.relu(t_725)
|
620 |
+
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
621 |
+
t_727 = self.n_Conv_160(t_726_padded)
|
622 |
+
t_728 = F.relu(t_727)
|
623 |
+
t_729 = self.n_Conv_161(t_728)
|
624 |
+
t_730 = torch.add(t_729, t_724)
|
625 |
+
t_731 = F.relu(t_730)
|
626 |
+
t_732 = self.n_Conv_162(t_731)
|
627 |
+
t_733 = F.relu(t_732)
|
628 |
+
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
629 |
+
t_734 = self.n_Conv_163(t_733_padded)
|
630 |
+
t_735 = F.relu(t_734)
|
631 |
+
t_736 = self.n_Conv_164(t_735)
|
632 |
+
t_737 = torch.add(t_736, t_731)
|
633 |
+
t_738 = F.relu(t_737)
|
634 |
+
t_739 = self.n_Conv_165(t_738)
|
635 |
+
t_740 = F.relu(t_739)
|
636 |
+
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
637 |
+
t_741 = self.n_Conv_166(t_740_padded)
|
638 |
+
t_742 = F.relu(t_741)
|
639 |
+
t_743 = self.n_Conv_167(t_742)
|
640 |
+
t_744 = torch.add(t_743, t_738)
|
641 |
+
t_745 = F.relu(t_744)
|
642 |
+
t_746 = self.n_Conv_168(t_745)
|
643 |
+
t_747 = self.n_Conv_169(t_745)
|
644 |
+
t_748 = F.relu(t_747)
|
645 |
+
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
646 |
+
t_749 = self.n_Conv_170(t_748_padded)
|
647 |
+
t_750 = F.relu(t_749)
|
648 |
+
t_751 = self.n_Conv_171(t_750)
|
649 |
+
t_752 = torch.add(t_751, t_746)
|
650 |
+
t_753 = F.relu(t_752)
|
651 |
+
t_754 = self.n_Conv_172(t_753)
|
652 |
+
t_755 = F.relu(t_754)
|
653 |
+
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
654 |
+
t_756 = self.n_Conv_173(t_755_padded)
|
655 |
+
t_757 = F.relu(t_756)
|
656 |
+
t_758 = self.n_Conv_174(t_757)
|
657 |
+
t_759 = torch.add(t_758, t_753)
|
658 |
+
t_760 = F.relu(t_759)
|
659 |
+
t_761 = self.n_Conv_175(t_760)
|
660 |
+
t_762 = F.relu(t_761)
|
661 |
+
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
662 |
+
t_763 = self.n_Conv_176(t_762_padded)
|
663 |
+
t_764 = F.relu(t_763)
|
664 |
+
t_765 = self.n_Conv_177(t_764)
|
665 |
+
t_766 = torch.add(t_765, t_760)
|
666 |
+
t_767 = F.relu(t_766)
|
667 |
+
t_768 = self.n_Conv_178(t_767)
|
668 |
+
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
669 |
+
t_770 = torch.squeeze(t_769, 3)
|
670 |
+
t_770 = torch.squeeze(t_770, 2)
|
671 |
+
t_771 = torch.sigmoid(t_770)
|
672 |
+
return t_771
|
673 |
+
|
674 |
+
def load_state_dict(self, state_dict, **kwargs):
|
675 |
+
self.tags = state_dict.get('tags', [])
|
676 |
+
|
677 |
+
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
678 |
+
|
modules/devices.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import contextlib
|
3 |
+
from functools import lru_cache
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from modules import errors
|
7 |
+
|
8 |
+
if sys.platform == "darwin":
|
9 |
+
from modules import mac_specific
|
10 |
+
|
11 |
+
|
12 |
+
def has_mps() -> bool:
|
13 |
+
if sys.platform != "darwin":
|
14 |
+
return False
|
15 |
+
else:
|
16 |
+
return mac_specific.has_mps
|
17 |
+
|
18 |
+
|
19 |
+
def get_cuda_device_string():
|
20 |
+
from modules import shared
|
21 |
+
|
22 |
+
if shared.cmd_opts.device_id is not None:
|
23 |
+
return f"cuda:{shared.cmd_opts.device_id}"
|
24 |
+
|
25 |
+
return "cuda"
|
26 |
+
|
27 |
+
|
28 |
+
def get_optimal_device_name():
|
29 |
+
if torch.cuda.is_available():
|
30 |
+
return get_cuda_device_string()
|
31 |
+
|
32 |
+
if has_mps():
|
33 |
+
return "mps"
|
34 |
+
|
35 |
+
return "cpu"
|
36 |
+
|
37 |
+
|
38 |
+
def get_optimal_device():
|
39 |
+
return torch.device(get_optimal_device_name())
|
40 |
+
|
41 |
+
|
42 |
+
def get_device_for(task):
|
43 |
+
from modules import shared
|
44 |
+
|
45 |
+
if task in shared.cmd_opts.use_cpu:
|
46 |
+
return cpu
|
47 |
+
|
48 |
+
return get_optimal_device()
|
49 |
+
|
50 |
+
|
51 |
+
def torch_gc():
|
52 |
+
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
with torch.cuda.device(get_cuda_device_string()):
|
55 |
+
torch.cuda.empty_cache()
|
56 |
+
torch.cuda.ipc_collect()
|
57 |
+
|
58 |
+
if has_mps():
|
59 |
+
mac_specific.torch_mps_gc()
|
60 |
+
|
61 |
+
|
62 |
+
def enable_tf32():
|
63 |
+
if torch.cuda.is_available():
|
64 |
+
|
65 |
+
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
66 |
+
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
67 |
+
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
68 |
+
torch.backends.cudnn.benchmark = True
|
69 |
+
|
70 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
71 |
+
torch.backends.cudnn.allow_tf32 = True
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
errors.run(enable_tf32, "Enabling TF32")
|
76 |
+
|
77 |
+
cpu = torch.device("cpu")
|
78 |
+
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
79 |
+
dtype = torch.float16
|
80 |
+
dtype_vae = torch.float16
|
81 |
+
dtype_unet = torch.float16
|
82 |
+
unet_needs_upcast = False
|
83 |
+
|
84 |
+
|
85 |
+
def cond_cast_unet(input):
|
86 |
+
return input.to(dtype_unet) if unet_needs_upcast else input
|
87 |
+
|
88 |
+
|
89 |
+
def cond_cast_float(input):
|
90 |
+
return input.float() if unet_needs_upcast else input
|
91 |
+
|
92 |
+
|
93 |
+
def randn(seed, shape):
|
94 |
+
from modules.shared import opts
|
95 |
+
|
96 |
+
torch.manual_seed(seed)
|
97 |
+
if opts.randn_source == "CPU" or device.type == 'mps':
|
98 |
+
return torch.randn(shape, device=cpu).to(device)
|
99 |
+
return torch.randn(shape, device=device)
|
100 |
+
|
101 |
+
|
102 |
+
def randn_without_seed(shape):
|
103 |
+
from modules.shared import opts
|
104 |
+
|
105 |
+
if opts.randn_source == "CPU" or device.type == 'mps':
|
106 |
+
return torch.randn(shape, device=cpu).to(device)
|
107 |
+
return torch.randn(shape, device=device)
|
108 |
+
|
109 |
+
|
110 |
+
def autocast(disable=False):
|
111 |
+
from modules import shared
|
112 |
+
|
113 |
+
if disable:
|
114 |
+
return contextlib.nullcontext()
|
115 |
+
|
116 |
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
117 |
+
return contextlib.nullcontext()
|
118 |
+
|
119 |
+
return torch.autocast("cuda")
|
120 |
+
|
121 |
+
|
122 |
+
def without_autocast(disable=False):
|
123 |
+
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
124 |
+
|
125 |
+
|
126 |
+
class NansException(Exception):
|
127 |
+
pass
|
128 |
+
|
129 |
+
|
130 |
+
def test_for_nans(x, where):
|
131 |
+
from modules import shared
|
132 |
+
|
133 |
+
if shared.cmd_opts.disable_nan_check:
|
134 |
+
return
|
135 |
+
|
136 |
+
if not torch.all(torch.isnan(x)).item():
|
137 |
+
return
|
138 |
+
|
139 |
+
if where == "unet":
|
140 |
+
message = "A tensor with all NaNs was produced in Unet."
|
141 |
+
|
142 |
+
if not shared.cmd_opts.no_half:
|
143 |
+
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
|
144 |
+
|
145 |
+
elif where == "vae":
|
146 |
+
message = "A tensor with all NaNs was produced in VAE."
|
147 |
+
|
148 |
+
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
|
149 |
+
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
|
150 |
+
else:
|
151 |
+
message = "A tensor with all NaNs was produced."
|
152 |
+
|
153 |
+
message += " Use --disable-nan-check commandline argument to disable this check."
|
154 |
+
|
155 |
+
raise NansException(message)
|
156 |
+
|
157 |
+
|
158 |
+
@lru_cache
|
159 |
+
def first_time_calculation():
|
160 |
+
"""
|
161 |
+
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
162 |
+
spends about 2.7 seconds doing that, at least wih NVidia.
|
163 |
+
"""
|
164 |
+
|
165 |
+
x = torch.zeros((1, 1)).to(device, dtype)
|
166 |
+
linear = torch.nn.Linear(1, 1).to(device, dtype)
|
167 |
+
linear(x)
|
168 |
+
|
169 |
+
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
170 |
+
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
171 |
+
conv2d(x)
|
modules/errors.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import textwrap
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
|
6 |
+
exception_records = []
|
7 |
+
|
8 |
+
|
9 |
+
def record_exception():
|
10 |
+
_, e, tb = sys.exc_info()
|
11 |
+
if e is None:
|
12 |
+
return
|
13 |
+
|
14 |
+
if exception_records and exception_records[-1] == e:
|
15 |
+
return
|
16 |
+
|
17 |
+
exception_records.append((e, tb))
|
18 |
+
|
19 |
+
if len(exception_records) > 5:
|
20 |
+
exception_records.pop(0)
|
21 |
+
|
22 |
+
|
23 |
+
def report(message: str, *, exc_info: bool = False) -> None:
|
24 |
+
"""
|
25 |
+
Print an error message to stderr, with optional traceback.
|
26 |
+
"""
|
27 |
+
|
28 |
+
record_exception()
|
29 |
+
|
30 |
+
for line in message.splitlines():
|
31 |
+
print("***", line, file=sys.stderr)
|
32 |
+
if exc_info:
|
33 |
+
print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
|
34 |
+
print("---", file=sys.stderr)
|
35 |
+
|
36 |
+
|
37 |
+
def print_error_explanation(message):
|
38 |
+
record_exception()
|
39 |
+
|
40 |
+
lines = message.strip().split("\n")
|
41 |
+
max_len = max([len(x) for x in lines])
|
42 |
+
|
43 |
+
print('=' * max_len, file=sys.stderr)
|
44 |
+
for line in lines:
|
45 |
+
print(line, file=sys.stderr)
|
46 |
+
print('=' * max_len, file=sys.stderr)
|
47 |
+
|
48 |
+
|
49 |
+
def display(e: Exception, task, *, full_traceback=False):
|
50 |
+
record_exception()
|
51 |
+
|
52 |
+
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
53 |
+
te = traceback.TracebackException.from_exception(e)
|
54 |
+
if full_traceback:
|
55 |
+
# include frames leading up to the try-catch block
|
56 |
+
te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
|
57 |
+
print(*te.format(), sep="", file=sys.stderr)
|
58 |
+
|
59 |
+
message = str(e)
|
60 |
+
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
61 |
+
print_error_explanation("""
|
62 |
+
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
|
63 |
+
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
64 |
+
""")
|
65 |
+
|
66 |
+
|
67 |
+
already_displayed = {}
|
68 |
+
|
69 |
+
|
70 |
+
def display_once(e: Exception, task):
|
71 |
+
record_exception()
|
72 |
+
|
73 |
+
if task in already_displayed:
|
74 |
+
return
|
75 |
+
|
76 |
+
display(e, task)
|
77 |
+
|
78 |
+
already_displayed[task] = 1
|
79 |
+
|
80 |
+
|
81 |
+
def run(code, task):
|
82 |
+
try:
|
83 |
+
code()
|
84 |
+
except Exception as e:
|
85 |
+
display(task, e)
|
modules/esrgan_model.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import modules.esrgan_model_arch as arch
|
8 |
+
from modules import modelloader, images, devices
|
9 |
+
from modules.shared import opts
|
10 |
+
from modules.upscaler import Upscaler, UpscalerData
|
11 |
+
|
12 |
+
|
13 |
+
def mod2normal(state_dict):
|
14 |
+
# this code is copied from https://github.com/victorca25/iNNfer
|
15 |
+
if 'conv_first.weight' in state_dict:
|
16 |
+
crt_net = {}
|
17 |
+
items = list(state_dict)
|
18 |
+
|
19 |
+
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
20 |
+
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
21 |
+
|
22 |
+
for k in items.copy():
|
23 |
+
if 'RDB' in k:
|
24 |
+
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
25 |
+
if '.weight' in k:
|
26 |
+
ori_k = ori_k.replace('.weight', '.0.weight')
|
27 |
+
elif '.bias' in k:
|
28 |
+
ori_k = ori_k.replace('.bias', '.0.bias')
|
29 |
+
crt_net[ori_k] = state_dict[k]
|
30 |
+
items.remove(k)
|
31 |
+
|
32 |
+
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
|
33 |
+
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
|
34 |
+
crt_net['model.3.weight'] = state_dict['upconv1.weight']
|
35 |
+
crt_net['model.3.bias'] = state_dict['upconv1.bias']
|
36 |
+
crt_net['model.6.weight'] = state_dict['upconv2.weight']
|
37 |
+
crt_net['model.6.bias'] = state_dict['upconv2.bias']
|
38 |
+
crt_net['model.8.weight'] = state_dict['HRconv.weight']
|
39 |
+
crt_net['model.8.bias'] = state_dict['HRconv.bias']
|
40 |
+
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
41 |
+
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
42 |
+
state_dict = crt_net
|
43 |
+
return state_dict
|
44 |
+
|
45 |
+
|
46 |
+
def resrgan2normal(state_dict, nb=23):
|
47 |
+
# this code is copied from https://github.com/victorca25/iNNfer
|
48 |
+
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
49 |
+
re8x = 0
|
50 |
+
crt_net = {}
|
51 |
+
items = list(state_dict)
|
52 |
+
|
53 |
+
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
54 |
+
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
55 |
+
|
56 |
+
for k in items.copy():
|
57 |
+
if "rdb" in k:
|
58 |
+
ori_k = k.replace('body.', 'model.1.sub.')
|
59 |
+
ori_k = ori_k.replace('.rdb', '.RDB')
|
60 |
+
if '.weight' in k:
|
61 |
+
ori_k = ori_k.replace('.weight', '.0.weight')
|
62 |
+
elif '.bias' in k:
|
63 |
+
ori_k = ori_k.replace('.bias', '.0.bias')
|
64 |
+
crt_net[ori_k] = state_dict[k]
|
65 |
+
items.remove(k)
|
66 |
+
|
67 |
+
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
|
68 |
+
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
|
69 |
+
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
|
70 |
+
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
|
71 |
+
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
|
72 |
+
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
|
73 |
+
|
74 |
+
if 'conv_up3.weight' in state_dict:
|
75 |
+
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
|
76 |
+
re8x = 3
|
77 |
+
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
|
78 |
+
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
|
79 |
+
|
80 |
+
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
|
81 |
+
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
|
82 |
+
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
|
83 |
+
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
|
84 |
+
|
85 |
+
state_dict = crt_net
|
86 |
+
return state_dict
|
87 |
+
|
88 |
+
|
89 |
+
def infer_params(state_dict):
|
90 |
+
# this code is copied from https://github.com/victorca25/iNNfer
|
91 |
+
scale2x = 0
|
92 |
+
scalemin = 6
|
93 |
+
n_uplayer = 0
|
94 |
+
plus = False
|
95 |
+
|
96 |
+
for block in list(state_dict):
|
97 |
+
parts = block.split(".")
|
98 |
+
n_parts = len(parts)
|
99 |
+
if n_parts == 5 and parts[2] == "sub":
|
100 |
+
nb = int(parts[3])
|
101 |
+
elif n_parts == 3:
|
102 |
+
part_num = int(parts[1])
|
103 |
+
if (part_num > scalemin
|
104 |
+
and parts[0] == "model"
|
105 |
+
and parts[2] == "weight"):
|
106 |
+
scale2x += 1
|
107 |
+
if part_num > n_uplayer:
|
108 |
+
n_uplayer = part_num
|
109 |
+
out_nc = state_dict[block].shape[0]
|
110 |
+
if not plus and "conv1x1" in block:
|
111 |
+
plus = True
|
112 |
+
|
113 |
+
nf = state_dict["model.0.weight"].shape[0]
|
114 |
+
in_nc = state_dict["model.0.weight"].shape[1]
|
115 |
+
out_nc = out_nc
|
116 |
+
scale = 2 ** scale2x
|
117 |
+
|
118 |
+
return in_nc, out_nc, nf, nb, plus, scale
|
119 |
+
|
120 |
+
|
121 |
+
class UpscalerESRGAN(Upscaler):
|
122 |
+
def __init__(self, dirname):
|
123 |
+
self.name = "ESRGAN"
|
124 |
+
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
|
125 |
+
self.model_name = "ESRGAN_4x"
|
126 |
+
self.scalers = []
|
127 |
+
self.user_path = dirname
|
128 |
+
super().__init__()
|
129 |
+
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
130 |
+
scalers = []
|
131 |
+
if len(model_paths) == 0:
|
132 |
+
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
133 |
+
scalers.append(scaler_data)
|
134 |
+
for file in model_paths:
|
135 |
+
if file.startswith("http"):
|
136 |
+
name = self.model_name
|
137 |
+
else:
|
138 |
+
name = modelloader.friendly_name(file)
|
139 |
+
|
140 |
+
scaler_data = UpscalerData(name, file, self, 4)
|
141 |
+
self.scalers.append(scaler_data)
|
142 |
+
|
143 |
+
def do_upscale(self, img, selected_model):
|
144 |
+
try:
|
145 |
+
model = self.load_model(selected_model)
|
146 |
+
except Exception as e:
|
147 |
+
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
148 |
+
return img
|
149 |
+
model.to(devices.device_esrgan)
|
150 |
+
img = esrgan_upscale(model, img)
|
151 |
+
return img
|
152 |
+
|
153 |
+
def load_model(self, path: str):
|
154 |
+
if path.startswith("http"):
|
155 |
+
# TODO: this doesn't use `path` at all?
|
156 |
+
filename = modelloader.load_file_from_url(
|
157 |
+
url=self.model_url,
|
158 |
+
model_dir=self.model_download_path,
|
159 |
+
file_name=f"{self.model_name}.pth",
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
filename = path
|
163 |
+
|
164 |
+
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
165 |
+
|
166 |
+
if "params_ema" in state_dict:
|
167 |
+
state_dict = state_dict["params_ema"]
|
168 |
+
elif "params" in state_dict:
|
169 |
+
state_dict = state_dict["params"]
|
170 |
+
num_conv = 16 if "realesr-animevideov3" in filename else 32
|
171 |
+
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
|
172 |
+
model.load_state_dict(state_dict)
|
173 |
+
model.eval()
|
174 |
+
return model
|
175 |
+
|
176 |
+
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
|
177 |
+
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
|
178 |
+
state_dict = resrgan2normal(state_dict, nb)
|
179 |
+
elif "conv_first.weight" in state_dict:
|
180 |
+
state_dict = mod2normal(state_dict)
|
181 |
+
elif "model.0.weight" not in state_dict:
|
182 |
+
raise Exception("The file is not a recognized ESRGAN model.")
|
183 |
+
|
184 |
+
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
|
185 |
+
|
186 |
+
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
|
187 |
+
model.load_state_dict(state_dict)
|
188 |
+
model.eval()
|
189 |
+
|
190 |
+
return model
|
191 |
+
|
192 |
+
|
193 |
+
def upscale_without_tiling(model, img):
|
194 |
+
img = np.array(img)
|
195 |
+
img = img[:, :, ::-1]
|
196 |
+
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
197 |
+
img = torch.from_numpy(img).float()
|
198 |
+
img = img.unsqueeze(0).to(devices.device_esrgan)
|
199 |
+
with torch.no_grad():
|
200 |
+
output = model(img)
|
201 |
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
202 |
+
output = 255. * np.moveaxis(output, 0, 2)
|
203 |
+
output = output.astype(np.uint8)
|
204 |
+
output = output[:, :, ::-1]
|
205 |
+
return Image.fromarray(output, 'RGB')
|
206 |
+
|
207 |
+
|
208 |
+
def esrgan_upscale(model, img):
|
209 |
+
if opts.ESRGAN_tile == 0:
|
210 |
+
return upscale_without_tiling(model, img)
|
211 |
+
|
212 |
+
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
213 |
+
newtiles = []
|
214 |
+
scale_factor = 1
|
215 |
+
|
216 |
+
for y, h, row in grid.tiles:
|
217 |
+
newrow = []
|
218 |
+
for tiledata in row:
|
219 |
+
x, w, tile = tiledata
|
220 |
+
|
221 |
+
output = upscale_without_tiling(model, tile)
|
222 |
+
scale_factor = output.width // tile.width
|
223 |
+
|
224 |
+
newrow.append([x * scale_factor, w * scale_factor, output])
|
225 |
+
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
226 |
+
|
227 |
+
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
228 |
+
output = images.combine_grid(newgrid)
|
229 |
+
return output
|
modules/esrgan_model_arch.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this file is adapted from https://github.com/victorca25/iNNfer
|
2 |
+
|
3 |
+
from collections import OrderedDict
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
####################
|
11 |
+
# RRDBNet Generator
|
12 |
+
####################
|
13 |
+
|
14 |
+
class RRDBNet(nn.Module):
|
15 |
+
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
16 |
+
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
17 |
+
finalact=None, gaussian_noise=False, plus=False):
|
18 |
+
super(RRDBNet, self).__init__()
|
19 |
+
n_upscale = int(math.log(upscale, 2))
|
20 |
+
if upscale == 3:
|
21 |
+
n_upscale = 1
|
22 |
+
|
23 |
+
self.resrgan_scale = 0
|
24 |
+
if in_nc % 16 == 0:
|
25 |
+
self.resrgan_scale = 1
|
26 |
+
elif in_nc != 4 and in_nc % 4 == 0:
|
27 |
+
self.resrgan_scale = 2
|
28 |
+
|
29 |
+
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
30 |
+
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
31 |
+
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
32 |
+
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
33 |
+
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
34 |
+
|
35 |
+
if upsample_mode == 'upconv':
|
36 |
+
upsample_block = upconv_block
|
37 |
+
elif upsample_mode == 'pixelshuffle':
|
38 |
+
upsample_block = pixelshuffle_block
|
39 |
+
else:
|
40 |
+
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
|
41 |
+
if upscale == 3:
|
42 |
+
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
43 |
+
else:
|
44 |
+
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
45 |
+
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
46 |
+
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
47 |
+
|
48 |
+
outact = act(finalact) if finalact else None
|
49 |
+
|
50 |
+
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
51 |
+
*upsampler, HR_conv0, HR_conv1, outact)
|
52 |
+
|
53 |
+
def forward(self, x, outm=None):
|
54 |
+
if self.resrgan_scale == 1:
|
55 |
+
feat = pixel_unshuffle(x, scale=4)
|
56 |
+
elif self.resrgan_scale == 2:
|
57 |
+
feat = pixel_unshuffle(x, scale=2)
|
58 |
+
else:
|
59 |
+
feat = x
|
60 |
+
|
61 |
+
return self.model(feat)
|
62 |
+
|
63 |
+
|
64 |
+
class RRDB(nn.Module):
|
65 |
+
"""
|
66 |
+
Residual in Residual Dense Block
|
67 |
+
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
71 |
+
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
72 |
+
spectral_norm=False, gaussian_noise=False, plus=False):
|
73 |
+
super(RRDB, self).__init__()
|
74 |
+
# This is for backwards compatibility with existing models
|
75 |
+
if nr == 3:
|
76 |
+
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
77 |
+
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
78 |
+
gaussian_noise=gaussian_noise, plus=plus)
|
79 |
+
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
80 |
+
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
81 |
+
gaussian_noise=gaussian_noise, plus=plus)
|
82 |
+
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
83 |
+
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
84 |
+
gaussian_noise=gaussian_noise, plus=plus)
|
85 |
+
else:
|
86 |
+
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
87 |
+
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
88 |
+
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
89 |
+
self.RDBs = nn.Sequential(*RDB_list)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
if hasattr(self, 'RDB1'):
|
93 |
+
out = self.RDB1(x)
|
94 |
+
out = self.RDB2(out)
|
95 |
+
out = self.RDB3(out)
|
96 |
+
else:
|
97 |
+
out = self.RDBs(x)
|
98 |
+
return out * 0.2 + x
|
99 |
+
|
100 |
+
|
101 |
+
class ResidualDenseBlock_5C(nn.Module):
|
102 |
+
"""
|
103 |
+
Residual Dense Block
|
104 |
+
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
105 |
+
Modified options that can be used:
|
106 |
+
- "Partial Convolution based Padding" arXiv:1811.11718
|
107 |
+
- "Spectral normalization" arXiv:1802.05957
|
108 |
+
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
109 |
+
{Rakotonirina} and A. {Rasoanaivo}
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
113 |
+
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
114 |
+
spectral_norm=False, gaussian_noise=False, plus=False):
|
115 |
+
super(ResidualDenseBlock_5C, self).__init__()
|
116 |
+
|
117 |
+
self.noise = GaussianNoise() if gaussian_noise else None
|
118 |
+
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
119 |
+
|
120 |
+
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
121 |
+
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
122 |
+
spectral_norm=spectral_norm)
|
123 |
+
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
124 |
+
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
125 |
+
spectral_norm=spectral_norm)
|
126 |
+
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
127 |
+
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
128 |
+
spectral_norm=spectral_norm)
|
129 |
+
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
130 |
+
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
131 |
+
spectral_norm=spectral_norm)
|
132 |
+
if mode == 'CNA':
|
133 |
+
last_act = None
|
134 |
+
else:
|
135 |
+
last_act = act_type
|
136 |
+
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
137 |
+
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
138 |
+
spectral_norm=spectral_norm)
|
139 |
+
|
140 |
+
def forward(self, x):
|
141 |
+
x1 = self.conv1(x)
|
142 |
+
x2 = self.conv2(torch.cat((x, x1), 1))
|
143 |
+
if self.conv1x1:
|
144 |
+
x2 = x2 + self.conv1x1(x)
|
145 |
+
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
146 |
+
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
147 |
+
if self.conv1x1:
|
148 |
+
x4 = x4 + x2
|
149 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
150 |
+
if self.noise:
|
151 |
+
return self.noise(x5.mul(0.2) + x)
|
152 |
+
else:
|
153 |
+
return x5 * 0.2 + x
|
154 |
+
|
155 |
+
|
156 |
+
####################
|
157 |
+
# ESRGANplus
|
158 |
+
####################
|
159 |
+
|
160 |
+
class GaussianNoise(nn.Module):
|
161 |
+
def __init__(self, sigma=0.1, is_relative_detach=False):
|
162 |
+
super().__init__()
|
163 |
+
self.sigma = sigma
|
164 |
+
self.is_relative_detach = is_relative_detach
|
165 |
+
self.noise = torch.tensor(0, dtype=torch.float)
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
if self.training and self.sigma != 0:
|
169 |
+
self.noise = self.noise.to(x.device)
|
170 |
+
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
171 |
+
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
172 |
+
x = x + sampled_noise
|
173 |
+
return x
|
174 |
+
|
175 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
176 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
177 |
+
|
178 |
+
|
179 |
+
####################
|
180 |
+
# SRVGGNetCompact
|
181 |
+
####################
|
182 |
+
|
183 |
+
class SRVGGNetCompact(nn.Module):
|
184 |
+
"""A compact VGG-style network structure for super-resolution.
|
185 |
+
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
186 |
+
"""
|
187 |
+
|
188 |
+
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
189 |
+
super(SRVGGNetCompact, self).__init__()
|
190 |
+
self.num_in_ch = num_in_ch
|
191 |
+
self.num_out_ch = num_out_ch
|
192 |
+
self.num_feat = num_feat
|
193 |
+
self.num_conv = num_conv
|
194 |
+
self.upscale = upscale
|
195 |
+
self.act_type = act_type
|
196 |
+
|
197 |
+
self.body = nn.ModuleList()
|
198 |
+
# the first conv
|
199 |
+
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
200 |
+
# the first activation
|
201 |
+
if act_type == 'relu':
|
202 |
+
activation = nn.ReLU(inplace=True)
|
203 |
+
elif act_type == 'prelu':
|
204 |
+
activation = nn.PReLU(num_parameters=num_feat)
|
205 |
+
elif act_type == 'leakyrelu':
|
206 |
+
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
207 |
+
self.body.append(activation)
|
208 |
+
|
209 |
+
# the body structure
|
210 |
+
for _ in range(num_conv):
|
211 |
+
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
212 |
+
# activation
|
213 |
+
if act_type == 'relu':
|
214 |
+
activation = nn.ReLU(inplace=True)
|
215 |
+
elif act_type == 'prelu':
|
216 |
+
activation = nn.PReLU(num_parameters=num_feat)
|
217 |
+
elif act_type == 'leakyrelu':
|
218 |
+
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
219 |
+
self.body.append(activation)
|
220 |
+
|
221 |
+
# the last conv
|
222 |
+
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
223 |
+
# upsample
|
224 |
+
self.upsampler = nn.PixelShuffle(upscale)
|
225 |
+
|
226 |
+
def forward(self, x):
|
227 |
+
out = x
|
228 |
+
for i in range(0, len(self.body)):
|
229 |
+
out = self.body[i](out)
|
230 |
+
|
231 |
+
out = self.upsampler(out)
|
232 |
+
# add the nearest upsampled image, so that the network learns the residual
|
233 |
+
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
234 |
+
out += base
|
235 |
+
return out
|
236 |
+
|
237 |
+
|
238 |
+
####################
|
239 |
+
# Upsampler
|
240 |
+
####################
|
241 |
+
|
242 |
+
class Upsample(nn.Module):
|
243 |
+
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
244 |
+
The input data is assumed to be of the form
|
245 |
+
`minibatch x channels x [optional depth] x [optional height] x width`.
|
246 |
+
"""
|
247 |
+
|
248 |
+
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
249 |
+
super(Upsample, self).__init__()
|
250 |
+
if isinstance(scale_factor, tuple):
|
251 |
+
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
252 |
+
else:
|
253 |
+
self.scale_factor = float(scale_factor) if scale_factor else None
|
254 |
+
self.mode = mode
|
255 |
+
self.size = size
|
256 |
+
self.align_corners = align_corners
|
257 |
+
|
258 |
+
def forward(self, x):
|
259 |
+
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
260 |
+
|
261 |
+
def extra_repr(self):
|
262 |
+
if self.scale_factor is not None:
|
263 |
+
info = f'scale_factor={self.scale_factor}'
|
264 |
+
else:
|
265 |
+
info = f'size={self.size}'
|
266 |
+
info += f', mode={self.mode}'
|
267 |
+
return info
|
268 |
+
|
269 |
+
|
270 |
+
def pixel_unshuffle(x, scale):
|
271 |
+
""" Pixel unshuffle.
|
272 |
+
Args:
|
273 |
+
x (Tensor): Input feature with shape (b, c, hh, hw).
|
274 |
+
scale (int): Downsample ratio.
|
275 |
+
Returns:
|
276 |
+
Tensor: the pixel unshuffled feature.
|
277 |
+
"""
|
278 |
+
b, c, hh, hw = x.size()
|
279 |
+
out_channel = c * (scale**2)
|
280 |
+
assert hh % scale == 0 and hw % scale == 0
|
281 |
+
h = hh // scale
|
282 |
+
w = hw // scale
|
283 |
+
x_view = x.view(b, c, h, scale, w, scale)
|
284 |
+
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
285 |
+
|
286 |
+
|
287 |
+
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
288 |
+
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
289 |
+
"""
|
290 |
+
Pixel shuffle layer
|
291 |
+
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
292 |
+
Neural Network, CVPR17)
|
293 |
+
"""
|
294 |
+
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
295 |
+
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
296 |
+
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
297 |
+
|
298 |
+
n = norm(norm_type, out_nc) if norm_type else None
|
299 |
+
a = act(act_type) if act_type else None
|
300 |
+
return sequential(conv, pixel_shuffle, n, a)
|
301 |
+
|
302 |
+
|
303 |
+
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
304 |
+
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
305 |
+
""" Upconv layer """
|
306 |
+
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
307 |
+
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
308 |
+
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
309 |
+
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
310 |
+
return sequential(upsample, conv)
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
|
319 |
+
####################
|
320 |
+
# Basic blocks
|
321 |
+
####################
|
322 |
+
|
323 |
+
|
324 |
+
def make_layer(basic_block, num_basic_block, **kwarg):
|
325 |
+
"""Make layers by stacking the same blocks.
|
326 |
+
Args:
|
327 |
+
basic_block (nn.module): nn.module class for basic block. (block)
|
328 |
+
num_basic_block (int): number of blocks. (n_layers)
|
329 |
+
Returns:
|
330 |
+
nn.Sequential: Stacked blocks in nn.Sequential.
|
331 |
+
"""
|
332 |
+
layers = []
|
333 |
+
for _ in range(num_basic_block):
|
334 |
+
layers.append(basic_block(**kwarg))
|
335 |
+
return nn.Sequential(*layers)
|
336 |
+
|
337 |
+
|
338 |
+
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
339 |
+
""" activation helper """
|
340 |
+
act_type = act_type.lower()
|
341 |
+
if act_type == 'relu':
|
342 |
+
layer = nn.ReLU(inplace)
|
343 |
+
elif act_type in ('leakyrelu', 'lrelu'):
|
344 |
+
layer = nn.LeakyReLU(neg_slope, inplace)
|
345 |
+
elif act_type == 'prelu':
|
346 |
+
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
347 |
+
elif act_type == 'tanh': # [-1, 1] range output
|
348 |
+
layer = nn.Tanh()
|
349 |
+
elif act_type == 'sigmoid': # [0, 1] range output
|
350 |
+
layer = nn.Sigmoid()
|
351 |
+
else:
|
352 |
+
raise NotImplementedError(f'activation layer [{act_type}] is not found')
|
353 |
+
return layer
|
354 |
+
|
355 |
+
|
356 |
+
class Identity(nn.Module):
|
357 |
+
def __init__(self, *kwargs):
|
358 |
+
super(Identity, self).__init__()
|
359 |
+
|
360 |
+
def forward(self, x, *kwargs):
|
361 |
+
return x
|
362 |
+
|
363 |
+
|
364 |
+
def norm(norm_type, nc):
|
365 |
+
""" Return a normalization layer """
|
366 |
+
norm_type = norm_type.lower()
|
367 |
+
if norm_type == 'batch':
|
368 |
+
layer = nn.BatchNorm2d(nc, affine=True)
|
369 |
+
elif norm_type == 'instance':
|
370 |
+
layer = nn.InstanceNorm2d(nc, affine=False)
|
371 |
+
elif norm_type == 'none':
|
372 |
+
def norm_layer(x): return Identity()
|
373 |
+
else:
|
374 |
+
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
|
375 |
+
return layer
|
376 |
+
|
377 |
+
|
378 |
+
def pad(pad_type, padding):
|
379 |
+
""" padding layer helper """
|
380 |
+
pad_type = pad_type.lower()
|
381 |
+
if padding == 0:
|
382 |
+
return None
|
383 |
+
if pad_type == 'reflect':
|
384 |
+
layer = nn.ReflectionPad2d(padding)
|
385 |
+
elif pad_type == 'replicate':
|
386 |
+
layer = nn.ReplicationPad2d(padding)
|
387 |
+
elif pad_type == 'zero':
|
388 |
+
layer = nn.ZeroPad2d(padding)
|
389 |
+
else:
|
390 |
+
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
|
391 |
+
return layer
|
392 |
+
|
393 |
+
|
394 |
+
def get_valid_padding(kernel_size, dilation):
|
395 |
+
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
396 |
+
padding = (kernel_size - 1) // 2
|
397 |
+
return padding
|
398 |
+
|
399 |
+
|
400 |
+
class ShortcutBlock(nn.Module):
|
401 |
+
""" Elementwise sum the output of a submodule to its input """
|
402 |
+
def __init__(self, submodule):
|
403 |
+
super(ShortcutBlock, self).__init__()
|
404 |
+
self.sub = submodule
|
405 |
+
|
406 |
+
def forward(self, x):
|
407 |
+
output = x + self.sub(x)
|
408 |
+
return output
|
409 |
+
|
410 |
+
def __repr__(self):
|
411 |
+
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
412 |
+
|
413 |
+
|
414 |
+
def sequential(*args):
|
415 |
+
""" Flatten Sequential. It unwraps nn.Sequential. """
|
416 |
+
if len(args) == 1:
|
417 |
+
if isinstance(args[0], OrderedDict):
|
418 |
+
raise NotImplementedError('sequential does not support OrderedDict input.')
|
419 |
+
return args[0] # No sequential is needed.
|
420 |
+
modules = []
|
421 |
+
for module in args:
|
422 |
+
if isinstance(module, nn.Sequential):
|
423 |
+
for submodule in module.children():
|
424 |
+
modules.append(submodule)
|
425 |
+
elif isinstance(module, nn.Module):
|
426 |
+
modules.append(module)
|
427 |
+
return nn.Sequential(*modules)
|
428 |
+
|
429 |
+
|
430 |
+
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
431 |
+
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
432 |
+
spectral_norm=False):
|
433 |
+
""" Conv layer with padding, normalization, activation """
|
434 |
+
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
|
435 |
+
padding = get_valid_padding(kernel_size, dilation)
|
436 |
+
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
437 |
+
padding = padding if pad_type == 'zero' else 0
|
438 |
+
|
439 |
+
if convtype=='PartialConv2D':
|
440 |
+
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
441 |
+
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
442 |
+
dilation=dilation, bias=bias, groups=groups)
|
443 |
+
elif convtype=='DeformConv2D':
|
444 |
+
from torchvision.ops import DeformConv2d # not tested
|
445 |
+
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
446 |
+
dilation=dilation, bias=bias, groups=groups)
|
447 |
+
elif convtype=='Conv3D':
|
448 |
+
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
449 |
+
dilation=dilation, bias=bias, groups=groups)
|
450 |
+
else:
|
451 |
+
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
452 |
+
dilation=dilation, bias=bias, groups=groups)
|
453 |
+
|
454 |
+
if spectral_norm:
|
455 |
+
c = nn.utils.spectral_norm(c)
|
456 |
+
|
457 |
+
a = act(act_type) if act_type else None
|
458 |
+
if 'CNA' in mode:
|
459 |
+
n = norm(norm_type, out_nc) if norm_type else None
|
460 |
+
return sequential(p, c, n, a)
|
461 |
+
elif mode == 'NAC':
|
462 |
+
if norm_type is None and act_type is not None:
|
463 |
+
a = act(act_type, inplace=False)
|
464 |
+
n = norm(norm_type, in_nc) if norm_type else None
|
465 |
+
return sequential(n, a, p, c)
|
modules/extensions.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import threading
|
3 |
+
|
4 |
+
from modules import shared, errors, cache
|
5 |
+
from modules.gitpython_hack import Repo
|
6 |
+
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
7 |
+
|
8 |
+
extensions = []
|
9 |
+
|
10 |
+
os.makedirs(extensions_dir, exist_ok=True)
|
11 |
+
|
12 |
+
|
13 |
+
def active():
|
14 |
+
if shared.opts.disable_all_extensions == "all":
|
15 |
+
return []
|
16 |
+
elif shared.opts.disable_all_extensions == "extra":
|
17 |
+
return [x for x in extensions if x.enabled and x.is_builtin]
|
18 |
+
else:
|
19 |
+
return [x for x in extensions if x.enabled]
|
20 |
+
|
21 |
+
|
22 |
+
class Extension:
|
23 |
+
lock = threading.Lock()
|
24 |
+
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
25 |
+
|
26 |
+
def __init__(self, name, path, enabled=True, is_builtin=False):
|
27 |
+
self.name = name
|
28 |
+
self.path = path
|
29 |
+
self.enabled = enabled
|
30 |
+
self.status = ''
|
31 |
+
self.can_update = False
|
32 |
+
self.is_builtin = is_builtin
|
33 |
+
self.commit_hash = ''
|
34 |
+
self.commit_date = None
|
35 |
+
self.version = ''
|
36 |
+
self.branch = None
|
37 |
+
self.remote = None
|
38 |
+
self.have_info_from_repo = False
|
39 |
+
|
40 |
+
def to_dict(self):
|
41 |
+
return {x: getattr(self, x) for x in self.cached_fields}
|
42 |
+
|
43 |
+
def from_dict(self, d):
|
44 |
+
for field in self.cached_fields:
|
45 |
+
setattr(self, field, d[field])
|
46 |
+
|
47 |
+
def read_info_from_repo(self):
|
48 |
+
if self.is_builtin or self.have_info_from_repo:
|
49 |
+
return
|
50 |
+
|
51 |
+
def read_from_repo():
|
52 |
+
with self.lock:
|
53 |
+
if self.have_info_from_repo:
|
54 |
+
return
|
55 |
+
|
56 |
+
self.do_read_info_from_repo()
|
57 |
+
|
58 |
+
return self.to_dict()
|
59 |
+
try:
|
60 |
+
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
61 |
+
self.from_dict(d)
|
62 |
+
except FileNotFoundError:
|
63 |
+
pass
|
64 |
+
self.status = 'unknown'
|
65 |
+
|
66 |
+
def do_read_info_from_repo(self):
|
67 |
+
repo = None
|
68 |
+
try:
|
69 |
+
if os.path.exists(os.path.join(self.path, ".git")):
|
70 |
+
repo = Repo(self.path)
|
71 |
+
except Exception:
|
72 |
+
errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
|
73 |
+
|
74 |
+
if repo is None or repo.bare:
|
75 |
+
self.remote = None
|
76 |
+
else:
|
77 |
+
try:
|
78 |
+
self.remote = next(repo.remote().urls, None)
|
79 |
+
commit = repo.head.commit
|
80 |
+
self.commit_date = commit.committed_date
|
81 |
+
if repo.active_branch:
|
82 |
+
self.branch = repo.active_branch.name
|
83 |
+
self.commit_hash = commit.hexsha
|
84 |
+
self.version = self.commit_hash[:8]
|
85 |
+
|
86 |
+
except Exception:
|
87 |
+
errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
|
88 |
+
self.remote = None
|
89 |
+
|
90 |
+
self.have_info_from_repo = True
|
91 |
+
|
92 |
+
def list_files(self, subdir, extension):
|
93 |
+
from modules import scripts
|
94 |
+
|
95 |
+
dirpath = os.path.join(self.path, subdir)
|
96 |
+
if not os.path.isdir(dirpath):
|
97 |
+
return []
|
98 |
+
|
99 |
+
res = []
|
100 |
+
for filename in sorted(os.listdir(dirpath)):
|
101 |
+
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
102 |
+
|
103 |
+
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
104 |
+
|
105 |
+
return res
|
106 |
+
|
107 |
+
def check_updates(self):
|
108 |
+
repo = Repo(self.path)
|
109 |
+
for fetch in repo.remote().fetch(dry_run=True):
|
110 |
+
if fetch.flags != fetch.HEAD_UPTODATE:
|
111 |
+
self.can_update = True
|
112 |
+
self.status = "new commits"
|
113 |
+
return
|
114 |
+
|
115 |
+
try:
|
116 |
+
origin = repo.rev_parse('origin')
|
117 |
+
if repo.head.commit != origin:
|
118 |
+
self.can_update = True
|
119 |
+
self.status = "behind HEAD"
|
120 |
+
return
|
121 |
+
except Exception:
|
122 |
+
self.can_update = False
|
123 |
+
self.status = "unknown (remote error)"
|
124 |
+
return
|
125 |
+
|
126 |
+
self.can_update = False
|
127 |
+
self.status = "latest"
|
128 |
+
|
129 |
+
def fetch_and_reset_hard(self, commit='origin'):
|
130 |
+
repo = Repo(self.path)
|
131 |
+
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
132 |
+
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
133 |
+
repo.git.fetch(all=True)
|
134 |
+
repo.git.reset(commit, hard=True)
|
135 |
+
self.have_info_from_repo = False
|
136 |
+
|
137 |
+
|
138 |
+
def list_extensions():
|
139 |
+
extensions.clear()
|
140 |
+
|
141 |
+
if not os.path.isdir(extensions_dir):
|
142 |
+
return
|
143 |
+
|
144 |
+
if shared.opts.disable_all_extensions == "all":
|
145 |
+
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
146 |
+
elif shared.opts.disable_all_extensions == "extra":
|
147 |
+
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
148 |
+
|
149 |
+
extension_paths = []
|
150 |
+
for dirname in [extensions_dir, extensions_builtin_dir]:
|
151 |
+
if not os.path.isdir(dirname):
|
152 |
+
return
|
153 |
+
|
154 |
+
for extension_dirname in sorted(os.listdir(dirname)):
|
155 |
+
path = os.path.join(dirname, extension_dirname)
|
156 |
+
if not os.path.isdir(path):
|
157 |
+
continue
|
158 |
+
|
159 |
+
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
160 |
+
|
161 |
+
for dirname, path, is_builtin in extension_paths:
|
162 |
+
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
163 |
+
extensions.append(extension)
|
modules/extra_networks.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
from modules import errors
|
5 |
+
|
6 |
+
extra_network_registry = {}
|
7 |
+
extra_network_aliases = {}
|
8 |
+
|
9 |
+
|
10 |
+
def initialize():
|
11 |
+
extra_network_registry.clear()
|
12 |
+
extra_network_aliases.clear()
|
13 |
+
|
14 |
+
|
15 |
+
def register_extra_network(extra_network):
|
16 |
+
extra_network_registry[extra_network.name] = extra_network
|
17 |
+
|
18 |
+
|
19 |
+
def register_extra_network_alias(extra_network, alias):
|
20 |
+
extra_network_aliases[alias] = extra_network
|
21 |
+
|
22 |
+
|
23 |
+
def register_default_extra_networks():
|
24 |
+
from modules.extra_networks_hypernet import ExtraNetworkHypernet
|
25 |
+
register_extra_network(ExtraNetworkHypernet())
|
26 |
+
|
27 |
+
|
28 |
+
class ExtraNetworkParams:
|
29 |
+
def __init__(self, items=None):
|
30 |
+
self.items = items or []
|
31 |
+
self.positional = []
|
32 |
+
self.named = {}
|
33 |
+
|
34 |
+
for item in self.items:
|
35 |
+
parts = item.split('=', 2) if isinstance(item, str) else [item]
|
36 |
+
if len(parts) == 2:
|
37 |
+
self.named[parts[0]] = parts[1]
|
38 |
+
else:
|
39 |
+
self.positional.append(item)
|
40 |
+
|
41 |
+
def __eq__(self, other):
|
42 |
+
return self.items == other.items
|
43 |
+
|
44 |
+
|
45 |
+
class ExtraNetwork:
|
46 |
+
def __init__(self, name):
|
47 |
+
self.name = name
|
48 |
+
|
49 |
+
def activate(self, p, params_list):
|
50 |
+
"""
|
51 |
+
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
|
52 |
+
Passes arguments related to this extra network in params_list.
|
53 |
+
User passes arguments by specifying this in his prompt:
|
54 |
+
|
55 |
+
<name:arg1:arg2:arg3>
|
56 |
+
|
57 |
+
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
58 |
+
separated by colon.
|
59 |
+
|
60 |
+
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
|
61 |
+
in this case, all effects of this extra networks should be disabled.
|
62 |
+
|
63 |
+
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
64 |
+
|
65 |
+
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
|
66 |
+
|
67 |
+
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
|
68 |
+
|
69 |
+
params_list will be:
|
70 |
+
|
71 |
+
[
|
72 |
+
ExtraNetworkParams(items=["agm", "1.1"]),
|
73 |
+
ExtraNetworkParams(items=["ray"])
|
74 |
+
]
|
75 |
+
|
76 |
+
"""
|
77 |
+
raise NotImplementedError
|
78 |
+
|
79 |
+
def deactivate(self, p):
|
80 |
+
"""
|
81 |
+
Called at the end of processing for housekeeping. No need to do anything here.
|
82 |
+
"""
|
83 |
+
|
84 |
+
raise NotImplementedError
|
85 |
+
|
86 |
+
|
87 |
+
def activate(p, extra_network_data):
|
88 |
+
"""call activate for extra networks in extra_network_data in specified order, then call
|
89 |
+
activate for all remaining registered networks with an empty argument list"""
|
90 |
+
|
91 |
+
activated = []
|
92 |
+
|
93 |
+
for extra_network_name, extra_network_args in extra_network_data.items():
|
94 |
+
extra_network = extra_network_registry.get(extra_network_name, None)
|
95 |
+
|
96 |
+
if extra_network is None:
|
97 |
+
extra_network = extra_network_aliases.get(extra_network_name, None)
|
98 |
+
|
99 |
+
if extra_network is None:
|
100 |
+
print(f"Skipping unknown extra network: {extra_network_name}")
|
101 |
+
continue
|
102 |
+
|
103 |
+
try:
|
104 |
+
extra_network.activate(p, extra_network_args)
|
105 |
+
activated.append(extra_network)
|
106 |
+
except Exception as e:
|
107 |
+
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
108 |
+
|
109 |
+
for extra_network_name, extra_network in extra_network_registry.items():
|
110 |
+
if extra_network in activated:
|
111 |
+
continue
|
112 |
+
|
113 |
+
try:
|
114 |
+
extra_network.activate(p, [])
|
115 |
+
except Exception as e:
|
116 |
+
errors.display(e, f"activating extra network {extra_network_name}")
|
117 |
+
|
118 |
+
if p.scripts is not None:
|
119 |
+
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
|
120 |
+
|
121 |
+
|
122 |
+
def deactivate(p, extra_network_data):
|
123 |
+
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
124 |
+
deactivate for all remaining registered networks"""
|
125 |
+
|
126 |
+
for extra_network_name in extra_network_data:
|
127 |
+
extra_network = extra_network_registry.get(extra_network_name, None)
|
128 |
+
if extra_network is None:
|
129 |
+
continue
|
130 |
+
|
131 |
+
try:
|
132 |
+
extra_network.deactivate(p)
|
133 |
+
except Exception as e:
|
134 |
+
errors.display(e, f"deactivating extra network {extra_network_name}")
|
135 |
+
|
136 |
+
for extra_network_name, extra_network in extra_network_registry.items():
|
137 |
+
args = extra_network_data.get(extra_network_name, None)
|
138 |
+
if args is not None:
|
139 |
+
continue
|
140 |
+
|
141 |
+
try:
|
142 |
+
extra_network.deactivate(p)
|
143 |
+
except Exception as e:
|
144 |
+
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
|
145 |
+
|
146 |
+
|
147 |
+
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
|
148 |
+
|
149 |
+
|
150 |
+
def parse_prompt(prompt):
|
151 |
+
res = defaultdict(list)
|
152 |
+
|
153 |
+
def found(m):
|
154 |
+
name = m.group(1)
|
155 |
+
args = m.group(2)
|
156 |
+
|
157 |
+
res[name].append(ExtraNetworkParams(items=args.split(":")))
|
158 |
+
|
159 |
+
return ""
|
160 |
+
|
161 |
+
prompt = re.sub(re_extra_net, found, prompt)
|
162 |
+
|
163 |
+
return prompt, res
|
164 |
+
|
165 |
+
|
166 |
+
def parse_prompts(prompts):
|
167 |
+
res = []
|
168 |
+
extra_data = None
|
169 |
+
|
170 |
+
for prompt in prompts:
|
171 |
+
updated_prompt, parsed_extra_data = parse_prompt(prompt)
|
172 |
+
|
173 |
+
if extra_data is None:
|
174 |
+
extra_data = parsed_extra_data
|
175 |
+
|
176 |
+
res.append(updated_prompt)
|
177 |
+
|
178 |
+
return res, extra_data
|
179 |
+
|
modules/extra_networks_hypernet.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules import extra_networks, shared
|
2 |
+
from modules.hypernetworks import hypernetwork
|
3 |
+
|
4 |
+
|
5 |
+
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
6 |
+
def __init__(self):
|
7 |
+
super().__init__('hypernet')
|
8 |
+
|
9 |
+
def activate(self, p, params_list):
|
10 |
+
additional = shared.opts.sd_hypernetwork
|
11 |
+
|
12 |
+
if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
|
13 |
+
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
|
14 |
+
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
|
15 |
+
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
16 |
+
|
17 |
+
names = []
|
18 |
+
multipliers = []
|
19 |
+
for params in params_list:
|
20 |
+
assert params.items
|
21 |
+
|
22 |
+
names.append(params.items[0])
|
23 |
+
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
24 |
+
|
25 |
+
hypernetwork.load_hypernetworks(names, multipliers)
|
26 |
+
|
27 |
+
def deactivate(self, p):
|
28 |
+
pass
|
modules/extras.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import shutil
|
4 |
+
import json
|
5 |
+
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import tqdm
|
9 |
+
|
10 |
+
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
11 |
+
from modules.ui_common import plaintext_to_html
|
12 |
+
import gradio as gr
|
13 |
+
import safetensors.torch
|
14 |
+
|
15 |
+
|
16 |
+
def run_pnginfo(image):
|
17 |
+
if image is None:
|
18 |
+
return '', '', ''
|
19 |
+
|
20 |
+
geninfo, items = images.read_info_from_image(image)
|
21 |
+
items = {**{'parameters': geninfo}, **items}
|
22 |
+
|
23 |
+
info = ''
|
24 |
+
for key, text in items.items():
|
25 |
+
info += f"""
|
26 |
+
<div>
|
27 |
+
<p><b>{plaintext_to_html(str(key))}</b></p>
|
28 |
+
<p>{plaintext_to_html(str(text))}</p>
|
29 |
+
</div>
|
30 |
+
""".strip()+"\n"
|
31 |
+
|
32 |
+
if len(info) == 0:
|
33 |
+
message = "Nothing found in the image."
|
34 |
+
info = f"<div><p>{message}<p></div>"
|
35 |
+
|
36 |
+
return '', geninfo, info
|
37 |
+
|
38 |
+
|
39 |
+
def create_config(ckpt_result, config_source, a, b, c):
|
40 |
+
def config(x):
|
41 |
+
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
42 |
+
return res if res != shared.sd_default_config else None
|
43 |
+
|
44 |
+
if config_source == 0:
|
45 |
+
cfg = config(a) or config(b) or config(c)
|
46 |
+
elif config_source == 1:
|
47 |
+
cfg = config(b)
|
48 |
+
elif config_source == 2:
|
49 |
+
cfg = config(c)
|
50 |
+
else:
|
51 |
+
cfg = None
|
52 |
+
|
53 |
+
if cfg is None:
|
54 |
+
return
|
55 |
+
|
56 |
+
filename, _ = os.path.splitext(ckpt_result)
|
57 |
+
checkpoint_filename = filename + ".yaml"
|
58 |
+
|
59 |
+
print("Copying config:")
|
60 |
+
print(" from:", cfg)
|
61 |
+
print(" to:", checkpoint_filename)
|
62 |
+
shutil.copyfile(cfg, checkpoint_filename)
|
63 |
+
|
64 |
+
|
65 |
+
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
66 |
+
|
67 |
+
|
68 |
+
def to_half(tensor, enable):
|
69 |
+
if enable and tensor.dtype == torch.float:
|
70 |
+
return tensor.half()
|
71 |
+
|
72 |
+
return tensor
|
73 |
+
|
74 |
+
|
75 |
+
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
76 |
+
shared.state.begin(job="model-merge")
|
77 |
+
|
78 |
+
def fail(message):
|
79 |
+
shared.state.textinfo = message
|
80 |
+
shared.state.end()
|
81 |
+
return [*[gr.update() for _ in range(4)], message]
|
82 |
+
|
83 |
+
def weighted_sum(theta0, theta1, alpha):
|
84 |
+
return ((1 - alpha) * theta0) + (alpha * theta1)
|
85 |
+
|
86 |
+
def get_difference(theta1, theta2):
|
87 |
+
return theta1 - theta2
|
88 |
+
|
89 |
+
def add_difference(theta0, theta1_2_diff, alpha):
|
90 |
+
return theta0 + (alpha * theta1_2_diff)
|
91 |
+
|
92 |
+
def filename_weighted_sum():
|
93 |
+
a = primary_model_info.model_name
|
94 |
+
b = secondary_model_info.model_name
|
95 |
+
Ma = round(1 - multiplier, 2)
|
96 |
+
Mb = round(multiplier, 2)
|
97 |
+
|
98 |
+
return f"{Ma}({a}) + {Mb}({b})"
|
99 |
+
|
100 |
+
def filename_add_difference():
|
101 |
+
a = primary_model_info.model_name
|
102 |
+
b = secondary_model_info.model_name
|
103 |
+
c = tertiary_model_info.model_name
|
104 |
+
M = round(multiplier, 2)
|
105 |
+
|
106 |
+
return f"{a} + {M}({b} - {c})"
|
107 |
+
|
108 |
+
def filename_nothing():
|
109 |
+
return primary_model_info.model_name
|
110 |
+
|
111 |
+
theta_funcs = {
|
112 |
+
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
113 |
+
"Add difference": (filename_add_difference, get_difference, add_difference),
|
114 |
+
"No interpolation": (filename_nothing, None, None),
|
115 |
+
}
|
116 |
+
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
117 |
+
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
118 |
+
|
119 |
+
if not primary_model_name:
|
120 |
+
return fail("Failed: Merging requires a primary model.")
|
121 |
+
|
122 |
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
123 |
+
|
124 |
+
if theta_func2 and not secondary_model_name:
|
125 |
+
return fail("Failed: Merging requires a secondary model.")
|
126 |
+
|
127 |
+
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
128 |
+
|
129 |
+
if theta_func1 and not tertiary_model_name:
|
130 |
+
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
131 |
+
|
132 |
+
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
133 |
+
|
134 |
+
result_is_inpainting_model = False
|
135 |
+
result_is_instruct_pix2pix_model = False
|
136 |
+
|
137 |
+
if theta_func2:
|
138 |
+
shared.state.textinfo = "Loading B"
|
139 |
+
print(f"Loading {secondary_model_info.filename}...")
|
140 |
+
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
141 |
+
else:
|
142 |
+
theta_1 = None
|
143 |
+
|
144 |
+
if theta_func1:
|
145 |
+
shared.state.textinfo = "Loading C"
|
146 |
+
print(f"Loading {tertiary_model_info.filename}...")
|
147 |
+
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
148 |
+
|
149 |
+
shared.state.textinfo = 'Merging B and C'
|
150 |
+
shared.state.sampling_steps = len(theta_1.keys())
|
151 |
+
for key in tqdm.tqdm(theta_1.keys()):
|
152 |
+
if key in checkpoint_dict_skip_on_merge:
|
153 |
+
continue
|
154 |
+
|
155 |
+
if 'model' in key:
|
156 |
+
if key in theta_2:
|
157 |
+
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
158 |
+
theta_1[key] = theta_func1(theta_1[key], t2)
|
159 |
+
else:
|
160 |
+
theta_1[key] = torch.zeros_like(theta_1[key])
|
161 |
+
|
162 |
+
shared.state.sampling_step += 1
|
163 |
+
del theta_2
|
164 |
+
|
165 |
+
shared.state.nextjob()
|
166 |
+
|
167 |
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
168 |
+
print(f"Loading {primary_model_info.filename}...")
|
169 |
+
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
170 |
+
|
171 |
+
print("Merging...")
|
172 |
+
shared.state.textinfo = 'Merging A and B'
|
173 |
+
shared.state.sampling_steps = len(theta_0.keys())
|
174 |
+
for key in tqdm.tqdm(theta_0.keys()):
|
175 |
+
if theta_1 and 'model' in key and key in theta_1:
|
176 |
+
|
177 |
+
if key in checkpoint_dict_skip_on_merge:
|
178 |
+
continue
|
179 |
+
|
180 |
+
a = theta_0[key]
|
181 |
+
b = theta_1[key]
|
182 |
+
|
183 |
+
# this enables merging an inpainting model (A) with another one (B);
|
184 |
+
# where normal model would have 4 channels, for latenst space, inpainting model would
|
185 |
+
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
186 |
+
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
187 |
+
if a.shape[1] == 4 and b.shape[1] == 9:
|
188 |
+
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
189 |
+
if a.shape[1] == 4 and b.shape[1] == 8:
|
190 |
+
raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
|
191 |
+
|
192 |
+
if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
|
193 |
+
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
|
194 |
+
result_is_instruct_pix2pix_model = True
|
195 |
+
else:
|
196 |
+
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
197 |
+
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
198 |
+
result_is_inpainting_model = True
|
199 |
+
else:
|
200 |
+
theta_0[key] = theta_func2(a, b, multiplier)
|
201 |
+
|
202 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
203 |
+
|
204 |
+
shared.state.sampling_step += 1
|
205 |
+
|
206 |
+
del theta_1
|
207 |
+
|
208 |
+
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
209 |
+
if bake_in_vae_filename is not None:
|
210 |
+
print(f"Baking in VAE from {bake_in_vae_filename}")
|
211 |
+
shared.state.textinfo = 'Baking in VAE'
|
212 |
+
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
213 |
+
|
214 |
+
for key in vae_dict.keys():
|
215 |
+
theta_0_key = 'first_stage_model.' + key
|
216 |
+
if theta_0_key in theta_0:
|
217 |
+
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
218 |
+
|
219 |
+
del vae_dict
|
220 |
+
|
221 |
+
if save_as_half and not theta_func2:
|
222 |
+
for key in theta_0.keys():
|
223 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
224 |
+
|
225 |
+
if discard_weights:
|
226 |
+
regex = re.compile(discard_weights)
|
227 |
+
for key in list(theta_0):
|
228 |
+
if re.search(regex, key):
|
229 |
+
theta_0.pop(key, None)
|
230 |
+
|
231 |
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
232 |
+
|
233 |
+
filename = filename_generator() if custom_name == '' else custom_name
|
234 |
+
filename += ".inpainting" if result_is_inpainting_model else ""
|
235 |
+
filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
|
236 |
+
filename += "." + checkpoint_format
|
237 |
+
|
238 |
+
output_modelname = os.path.join(ckpt_dir, filename)
|
239 |
+
|
240 |
+
shared.state.nextjob()
|
241 |
+
shared.state.textinfo = "Saving"
|
242 |
+
print(f"Saving to {output_modelname}...")
|
243 |
+
|
244 |
+
metadata = None
|
245 |
+
|
246 |
+
if save_metadata:
|
247 |
+
metadata = {"format": "pt"}
|
248 |
+
|
249 |
+
merge_recipe = {
|
250 |
+
"type": "webui", # indicate this model was merged with webui's built-in merger
|
251 |
+
"primary_model_hash": primary_model_info.sha256,
|
252 |
+
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
253 |
+
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
254 |
+
"interp_method": interp_method,
|
255 |
+
"multiplier": multiplier,
|
256 |
+
"save_as_half": save_as_half,
|
257 |
+
"custom_name": custom_name,
|
258 |
+
"config_source": config_source,
|
259 |
+
"bake_in_vae": bake_in_vae,
|
260 |
+
"discard_weights": discard_weights,
|
261 |
+
"is_inpainting": result_is_inpainting_model,
|
262 |
+
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
263 |
+
}
|
264 |
+
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
265 |
+
|
266 |
+
sd_merge_models = {}
|
267 |
+
|
268 |
+
def add_model_metadata(checkpoint_info):
|
269 |
+
checkpoint_info.calculate_shorthash()
|
270 |
+
sd_merge_models[checkpoint_info.sha256] = {
|
271 |
+
"name": checkpoint_info.name,
|
272 |
+
"legacy_hash": checkpoint_info.hash,
|
273 |
+
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
274 |
+
}
|
275 |
+
|
276 |
+
sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
277 |
+
|
278 |
+
add_model_metadata(primary_model_info)
|
279 |
+
if secondary_model_info:
|
280 |
+
add_model_metadata(secondary_model_info)
|
281 |
+
if tertiary_model_info:
|
282 |
+
add_model_metadata(tertiary_model_info)
|
283 |
+
|
284 |
+
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
285 |
+
|
286 |
+
_, extension = os.path.splitext(output_modelname)
|
287 |
+
if extension.lower() == ".safetensors":
|
288 |
+
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
289 |
+
else:
|
290 |
+
torch.save(theta_0, output_modelname)
|
291 |
+
|
292 |
+
sd_models.list_models()
|
293 |
+
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
294 |
+
if created_model:
|
295 |
+
created_model.calculate_shorthash()
|
296 |
+
|
297 |
+
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
298 |
+
|
299 |
+
print(f"Checkpoint saved to {output_modelname}.")
|
300 |
+
shared.state.textinfo = "Checkpoint saved"
|
301 |
+
shared.state.end()
|
302 |
+
|
303 |
+
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
modules/face_restoration.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules import shared
|
2 |
+
|
3 |
+
|
4 |
+
class FaceRestoration:
|
5 |
+
def name(self):
|
6 |
+
return "None"
|
7 |
+
|
8 |
+
def restore(self, np_image):
|
9 |
+
return np_image
|
10 |
+
|
11 |
+
|
12 |
+
def restore_faces(np_image):
|
13 |
+
face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
|
14 |
+
if len(face_restorers) == 0:
|
15 |
+
return np_image
|
16 |
+
|
17 |
+
face_restorer = face_restorers[0]
|
18 |
+
|
19 |
+
return face_restorer.restore(np_image)
|
modules/generation_parameters_copypaste.py
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
from modules.paths import data_path
|
9 |
+
from modules import shared, ui_tempdir, script_callbacks
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
13 |
+
re_param = re.compile(re_param_code)
|
14 |
+
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
15 |
+
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
16 |
+
type_of_gr_update = type(gr.update())
|
17 |
+
|
18 |
+
paste_fields = {}
|
19 |
+
registered_param_bindings = []
|
20 |
+
|
21 |
+
|
22 |
+
class ParamBinding:
|
23 |
+
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
24 |
+
self.paste_button = paste_button
|
25 |
+
self.tabname = tabname
|
26 |
+
self.source_text_component = source_text_component
|
27 |
+
self.source_image_component = source_image_component
|
28 |
+
self.source_tabname = source_tabname
|
29 |
+
self.override_settings_component = override_settings_component
|
30 |
+
self.paste_field_names = paste_field_names or []
|
31 |
+
|
32 |
+
|
33 |
+
def reset():
|
34 |
+
paste_fields.clear()
|
35 |
+
|
36 |
+
|
37 |
+
def quote(text):
|
38 |
+
if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
|
39 |
+
return text
|
40 |
+
|
41 |
+
return json.dumps(text, ensure_ascii=False)
|
42 |
+
|
43 |
+
|
44 |
+
def unquote(text):
|
45 |
+
if len(text) == 0 or text[0] != '"' or text[-1] != '"':
|
46 |
+
return text
|
47 |
+
|
48 |
+
try:
|
49 |
+
return json.loads(text)
|
50 |
+
except Exception:
|
51 |
+
return text
|
52 |
+
|
53 |
+
|
54 |
+
def image_from_url_text(filedata):
|
55 |
+
if filedata is None:
|
56 |
+
return None
|
57 |
+
|
58 |
+
if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
|
59 |
+
filedata = filedata[0]
|
60 |
+
|
61 |
+
if type(filedata) == dict and filedata.get("is_file", False):
|
62 |
+
filename = filedata["name"]
|
63 |
+
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
64 |
+
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
65 |
+
|
66 |
+
filename = filename.rsplit('?', 1)[0]
|
67 |
+
return Image.open(filename)
|
68 |
+
|
69 |
+
if type(filedata) == list:
|
70 |
+
if len(filedata) == 0:
|
71 |
+
return None
|
72 |
+
|
73 |
+
filedata = filedata[0]
|
74 |
+
|
75 |
+
if filedata.startswith("data:image/png;base64,"):
|
76 |
+
filedata = filedata[len("data:image/png;base64,"):]
|
77 |
+
|
78 |
+
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
79 |
+
image = Image.open(io.BytesIO(filedata))
|
80 |
+
return image
|
81 |
+
|
82 |
+
|
83 |
+
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
|
84 |
+
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
|
85 |
+
|
86 |
+
# backwards compatibility for existing extensions
|
87 |
+
import modules.ui
|
88 |
+
if tabname == 'txt2img':
|
89 |
+
modules.ui.txt2img_paste_fields = fields
|
90 |
+
elif tabname == 'img2img':
|
91 |
+
modules.ui.img2img_paste_fields = fields
|
92 |
+
|
93 |
+
|
94 |
+
def create_buttons(tabs_list):
|
95 |
+
buttons = {}
|
96 |
+
for tab in tabs_list:
|
97 |
+
buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
|
98 |
+
return buttons
|
99 |
+
|
100 |
+
|
101 |
+
def bind_buttons(buttons, send_image, send_generate_info):
|
102 |
+
"""old function for backwards compatibility; do not use this, use register_paste_params_button"""
|
103 |
+
for tabname, button in buttons.items():
|
104 |
+
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
105 |
+
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
106 |
+
|
107 |
+
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
|
108 |
+
|
109 |
+
|
110 |
+
def register_paste_params_button(binding: ParamBinding):
|
111 |
+
registered_param_bindings.append(binding)
|
112 |
+
|
113 |
+
|
114 |
+
def connect_paste_params_buttons():
|
115 |
+
binding: ParamBinding
|
116 |
+
for binding in registered_param_bindings:
|
117 |
+
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
118 |
+
fields = paste_fields[binding.tabname]["fields"]
|
119 |
+
override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
|
120 |
+
|
121 |
+
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
122 |
+
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
123 |
+
|
124 |
+
if binding.source_image_component and destination_image_component:
|
125 |
+
if isinstance(binding.source_image_component, gr.Gallery):
|
126 |
+
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
127 |
+
jsfunc = "extract_image_from_gallery"
|
128 |
+
else:
|
129 |
+
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
130 |
+
jsfunc = None
|
131 |
+
|
132 |
+
binding.paste_button.click(
|
133 |
+
fn=func,
|
134 |
+
_js=jsfunc,
|
135 |
+
inputs=[binding.source_image_component],
|
136 |
+
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
137 |
+
show_progress=False,
|
138 |
+
)
|
139 |
+
|
140 |
+
if binding.source_text_component is not None and fields is not None:
|
141 |
+
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
|
142 |
+
|
143 |
+
if binding.source_tabname is not None and fields is not None:
|
144 |
+
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
|
145 |
+
binding.paste_button.click(
|
146 |
+
fn=lambda *x: x,
|
147 |
+
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
148 |
+
outputs=[field for field, name in fields if name in paste_field_names],
|
149 |
+
show_progress=False,
|
150 |
+
)
|
151 |
+
|
152 |
+
binding.paste_button.click(
|
153 |
+
fn=None,
|
154 |
+
_js=f"switch_to_{binding.tabname}",
|
155 |
+
inputs=None,
|
156 |
+
outputs=None,
|
157 |
+
show_progress=False,
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
def send_image_and_dimensions(x):
|
162 |
+
if isinstance(x, Image.Image):
|
163 |
+
img = x
|
164 |
+
else:
|
165 |
+
img = image_from_url_text(x)
|
166 |
+
|
167 |
+
if shared.opts.send_size and isinstance(img, Image.Image):
|
168 |
+
w = img.width
|
169 |
+
h = img.height
|
170 |
+
else:
|
171 |
+
w = gr.update()
|
172 |
+
h = gr.update()
|
173 |
+
|
174 |
+
return img, w, h
|
175 |
+
|
176 |
+
|
177 |
+
def restore_old_hires_fix_params(res):
|
178 |
+
"""for infotexts that specify old First pass size parameter, convert it into
|
179 |
+
width, height, and hr scale"""
|
180 |
+
|
181 |
+
firstpass_width = res.get('First pass size-1', None)
|
182 |
+
firstpass_height = res.get('First pass size-2', None)
|
183 |
+
|
184 |
+
if shared.opts.use_old_hires_fix_width_height:
|
185 |
+
hires_width = int(res.get("Hires resize-1", 0))
|
186 |
+
hires_height = int(res.get("Hires resize-2", 0))
|
187 |
+
|
188 |
+
if hires_width and hires_height:
|
189 |
+
res['Size-1'] = hires_width
|
190 |
+
res['Size-2'] = hires_height
|
191 |
+
return
|
192 |
+
|
193 |
+
if firstpass_width is None or firstpass_height is None:
|
194 |
+
return
|
195 |
+
|
196 |
+
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
|
197 |
+
width = int(res.get("Size-1", 512))
|
198 |
+
height = int(res.get("Size-2", 512))
|
199 |
+
|
200 |
+
if firstpass_width == 0 or firstpass_height == 0:
|
201 |
+
from modules import processing
|
202 |
+
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
203 |
+
|
204 |
+
res['Size-1'] = firstpass_width
|
205 |
+
res['Size-2'] = firstpass_height
|
206 |
+
res['Hires resize-1'] = width
|
207 |
+
res['Hires resize-2'] = height
|
208 |
+
|
209 |
+
|
210 |
+
def parse_generation_parameters(x: str):
|
211 |
+
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
212 |
+
```
|
213 |
+
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
|
214 |
+
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
|
215 |
+
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
|
216 |
+
```
|
217 |
+
|
218 |
+
returns a dict with field values
|
219 |
+
"""
|
220 |
+
|
221 |
+
res = {}
|
222 |
+
|
223 |
+
prompt = ""
|
224 |
+
negative_prompt = ""
|
225 |
+
|
226 |
+
done_with_prompt = False
|
227 |
+
|
228 |
+
*lines, lastline = x.strip().split("\n")
|
229 |
+
if len(re_param.findall(lastline)) < 3:
|
230 |
+
lines.append(lastline)
|
231 |
+
lastline = ''
|
232 |
+
|
233 |
+
for line in lines:
|
234 |
+
line = line.strip()
|
235 |
+
if line.startswith("Negative prompt:"):
|
236 |
+
done_with_prompt = True
|
237 |
+
line = line[16:].strip()
|
238 |
+
if done_with_prompt:
|
239 |
+
negative_prompt += ("" if negative_prompt == "" else "\n") + line
|
240 |
+
else:
|
241 |
+
prompt += ("" if prompt == "" else "\n") + line
|
242 |
+
|
243 |
+
if shared.opts.infotext_styles != "Ignore":
|
244 |
+
found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
|
245 |
+
|
246 |
+
if shared.opts.infotext_styles == "Apply":
|
247 |
+
res["Styles array"] = found_styles
|
248 |
+
elif shared.opts.infotext_styles == "Apply if any" and found_styles:
|
249 |
+
res["Styles array"] = found_styles
|
250 |
+
|
251 |
+
res["Prompt"] = prompt
|
252 |
+
res["Negative prompt"] = negative_prompt
|
253 |
+
|
254 |
+
for k, v in re_param.findall(lastline):
|
255 |
+
try:
|
256 |
+
if v[0] == '"' and v[-1] == '"':
|
257 |
+
v = unquote(v)
|
258 |
+
|
259 |
+
m = re_imagesize.match(v)
|
260 |
+
if m is not None:
|
261 |
+
res[f"{k}-1"] = m.group(1)
|
262 |
+
res[f"{k}-2"] = m.group(2)
|
263 |
+
else:
|
264 |
+
res[k] = v
|
265 |
+
except Exception:
|
266 |
+
print(f"Error parsing \"{k}: {v}\"")
|
267 |
+
|
268 |
+
# Missing CLIP skip means it was set to 1 (the default)
|
269 |
+
if "Clip skip" not in res:
|
270 |
+
res["Clip skip"] = "1"
|
271 |
+
|
272 |
+
hypernet = res.get("Hypernet", None)
|
273 |
+
if hypernet is not None:
|
274 |
+
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
275 |
+
|
276 |
+
if "Hires resize-1" not in res:
|
277 |
+
res["Hires resize-1"] = 0
|
278 |
+
res["Hires resize-2"] = 0
|
279 |
+
|
280 |
+
if "Hires sampler" not in res:
|
281 |
+
res["Hires sampler"] = "Use same sampler"
|
282 |
+
|
283 |
+
if "Hires prompt" not in res:
|
284 |
+
res["Hires prompt"] = ""
|
285 |
+
|
286 |
+
if "Hires negative prompt" not in res:
|
287 |
+
res["Hires negative prompt"] = ""
|
288 |
+
|
289 |
+
restore_old_hires_fix_params(res)
|
290 |
+
|
291 |
+
# Missing RNG means the default was set, which is GPU RNG
|
292 |
+
if "RNG" not in res:
|
293 |
+
res["RNG"] = "GPU"
|
294 |
+
|
295 |
+
if "Schedule type" not in res:
|
296 |
+
res["Schedule type"] = "Automatic"
|
297 |
+
|
298 |
+
if "Schedule max sigma" not in res:
|
299 |
+
res["Schedule max sigma"] = 0
|
300 |
+
|
301 |
+
if "Schedule min sigma" not in res:
|
302 |
+
res["Schedule min sigma"] = 0
|
303 |
+
|
304 |
+
if "Schedule rho" not in res:
|
305 |
+
res["Schedule rho"] = 0
|
306 |
+
|
307 |
+
return res
|
308 |
+
|
309 |
+
|
310 |
+
infotext_to_setting_name_mapping = [
|
311 |
+
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
312 |
+
('Conditional mask weight', 'inpainting_mask_weight'),
|
313 |
+
('Model hash', 'sd_model_checkpoint'),
|
314 |
+
('ENSD', 'eta_noise_seed_delta'),
|
315 |
+
('Schedule type', 'k_sched_type'),
|
316 |
+
('Schedule max sigma', 'sigma_max'),
|
317 |
+
('Schedule min sigma', 'sigma_min'),
|
318 |
+
('Schedule rho', 'rho'),
|
319 |
+
('Noise multiplier', 'initial_noise_multiplier'),
|
320 |
+
('Eta', 'eta_ancestral'),
|
321 |
+
('Eta DDIM', 'eta_ddim'),
|
322 |
+
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
323 |
+
('UniPC variant', 'uni_pc_variant'),
|
324 |
+
('UniPC skip type', 'uni_pc_skip_type'),
|
325 |
+
('UniPC order', 'uni_pc_order'),
|
326 |
+
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
327 |
+
('Token merging ratio', 'token_merging_ratio'),
|
328 |
+
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
329 |
+
('RNG', 'randn_source'),
|
330 |
+
('NGMS', 's_min_uncond'),
|
331 |
+
('Pad conds', 'pad_cond_uncond'),
|
332 |
+
]
|
333 |
+
|
334 |
+
|
335 |
+
def create_override_settings_dict(text_pairs):
|
336 |
+
"""creates processing's override_settings parameters from gradio's multiselect
|
337 |
+
|
338 |
+
Example input:
|
339 |
+
['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
|
340 |
+
|
341 |
+
Example output:
|
342 |
+
{'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
|
343 |
+
"""
|
344 |
+
|
345 |
+
res = {}
|
346 |
+
|
347 |
+
params = {}
|
348 |
+
for pair in text_pairs:
|
349 |
+
k, v = pair.split(":", maxsplit=1)
|
350 |
+
|
351 |
+
params[k] = v.strip()
|
352 |
+
|
353 |
+
for param_name, setting_name in infotext_to_setting_name_mapping:
|
354 |
+
value = params.get(param_name, None)
|
355 |
+
|
356 |
+
if value is None:
|
357 |
+
continue
|
358 |
+
|
359 |
+
res[setting_name] = shared.opts.cast_value(setting_name, value)
|
360 |
+
|
361 |
+
return res
|
362 |
+
|
363 |
+
|
364 |
+
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
365 |
+
def paste_func(prompt):
|
366 |
+
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
367 |
+
filename = os.path.join(data_path, "params.txt")
|
368 |
+
if os.path.exists(filename):
|
369 |
+
with open(filename, "r", encoding="utf8") as file:
|
370 |
+
prompt = file.read()
|
371 |
+
|
372 |
+
params = parse_generation_parameters(prompt)
|
373 |
+
script_callbacks.infotext_pasted_callback(prompt, params)
|
374 |
+
res = []
|
375 |
+
|
376 |
+
for output, key in paste_fields:
|
377 |
+
if callable(key):
|
378 |
+
v = key(params)
|
379 |
+
else:
|
380 |
+
v = params.get(key, None)
|
381 |
+
|
382 |
+
if v is None:
|
383 |
+
res.append(gr.update())
|
384 |
+
elif isinstance(v, type_of_gr_update):
|
385 |
+
res.append(v)
|
386 |
+
else:
|
387 |
+
try:
|
388 |
+
valtype = type(output.value)
|
389 |
+
|
390 |
+
if valtype == bool and v == "False":
|
391 |
+
val = False
|
392 |
+
else:
|
393 |
+
val = valtype(v)
|
394 |
+
|
395 |
+
res.append(gr.update(value=val))
|
396 |
+
except Exception:
|
397 |
+
res.append(gr.update())
|
398 |
+
|
399 |
+
return res
|
400 |
+
|
401 |
+
if override_settings_component is not None:
|
402 |
+
def paste_settings(params):
|
403 |
+
vals = {}
|
404 |
+
|
405 |
+
for param_name, setting_name in infotext_to_setting_name_mapping:
|
406 |
+
v = params.get(param_name, None)
|
407 |
+
if v is None:
|
408 |
+
continue
|
409 |
+
|
410 |
+
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
411 |
+
continue
|
412 |
+
|
413 |
+
v = shared.opts.cast_value(setting_name, v)
|
414 |
+
current_value = getattr(shared.opts, setting_name, None)
|
415 |
+
|
416 |
+
if v == current_value:
|
417 |
+
continue
|
418 |
+
|
419 |
+
vals[param_name] = v
|
420 |
+
|
421 |
+
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
|
422 |
+
|
423 |
+
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
|
424 |
+
|
425 |
+
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
|
426 |
+
|
427 |
+
button.click(
|
428 |
+
fn=paste_func,
|
429 |
+
inputs=[input_comp],
|
430 |
+
outputs=[x[0] for x in paste_fields],
|
431 |
+
show_progress=False,
|
432 |
+
)
|
433 |
+
button.click(
|
434 |
+
fn=None,
|
435 |
+
_js=f"recalculate_prompts_{tabname}",
|
436 |
+
inputs=[],
|
437 |
+
outputs=[],
|
438 |
+
show_progress=False,
|
439 |
+
)
|
modules/gfpgan_model.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import facexlib
|
4 |
+
import gfpgan
|
5 |
+
|
6 |
+
import modules.face_restoration
|
7 |
+
from modules import paths, shared, devices, modelloader, errors
|
8 |
+
|
9 |
+
model_dir = "GFPGAN"
|
10 |
+
user_path = None
|
11 |
+
model_path = os.path.join(paths.models_path, model_dir)
|
12 |
+
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
13 |
+
have_gfpgan = False
|
14 |
+
loaded_gfpgan_model = None
|
15 |
+
|
16 |
+
|
17 |
+
def gfpgann():
|
18 |
+
global loaded_gfpgan_model
|
19 |
+
global model_path
|
20 |
+
if loaded_gfpgan_model is not None:
|
21 |
+
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
22 |
+
return loaded_gfpgan_model
|
23 |
+
|
24 |
+
if gfpgan_constructor is None:
|
25 |
+
return None
|
26 |
+
|
27 |
+
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
28 |
+
if len(models) == 1 and models[0].startswith("http"):
|
29 |
+
model_file = models[0]
|
30 |
+
elif len(models) != 0:
|
31 |
+
latest_file = max(models, key=os.path.getctime)
|
32 |
+
model_file = latest_file
|
33 |
+
else:
|
34 |
+
print("Unable to load gfpgan model!")
|
35 |
+
return None
|
36 |
+
if hasattr(facexlib.detection.retinaface, 'device'):
|
37 |
+
facexlib.detection.retinaface.device = devices.device_gfpgan
|
38 |
+
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
39 |
+
loaded_gfpgan_model = model
|
40 |
+
|
41 |
+
return model
|
42 |
+
|
43 |
+
|
44 |
+
def send_model_to(model, device):
|
45 |
+
model.gfpgan.to(device)
|
46 |
+
model.face_helper.face_det.to(device)
|
47 |
+
model.face_helper.face_parse.to(device)
|
48 |
+
|
49 |
+
|
50 |
+
def gfpgan_fix_faces(np_image):
|
51 |
+
model = gfpgann()
|
52 |
+
if model is None:
|
53 |
+
return np_image
|
54 |
+
|
55 |
+
send_model_to(model, devices.device_gfpgan)
|
56 |
+
|
57 |
+
np_image_bgr = np_image[:, :, ::-1]
|
58 |
+
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
59 |
+
np_image = gfpgan_output_bgr[:, :, ::-1]
|
60 |
+
|
61 |
+
model.face_helper.clean_all()
|
62 |
+
|
63 |
+
if shared.opts.face_restoration_unload:
|
64 |
+
send_model_to(model, devices.cpu)
|
65 |
+
|
66 |
+
return np_image
|
67 |
+
|
68 |
+
|
69 |
+
gfpgan_constructor = None
|
70 |
+
|
71 |
+
|
72 |
+
def setup_model(dirname):
|
73 |
+
try:
|
74 |
+
os.makedirs(model_path, exist_ok=True)
|
75 |
+
from gfpgan import GFPGANer
|
76 |
+
from facexlib import detection, parsing # noqa: F401
|
77 |
+
global user_path
|
78 |
+
global have_gfpgan
|
79 |
+
global gfpgan_constructor
|
80 |
+
|
81 |
+
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
82 |
+
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
83 |
+
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
84 |
+
|
85 |
+
def my_load_file_from_url(**kwargs):
|
86 |
+
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
87 |
+
|
88 |
+
def facex_load_file_from_url(**kwargs):
|
89 |
+
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
90 |
+
|
91 |
+
def facex_load_file_from_url2(**kwargs):
|
92 |
+
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
93 |
+
|
94 |
+
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
95 |
+
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
96 |
+
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
97 |
+
user_path = dirname
|
98 |
+
have_gfpgan = True
|
99 |
+
gfpgan_constructor = GFPGANer
|
100 |
+
|
101 |
+
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
102 |
+
def name(self):
|
103 |
+
return "GFPGAN"
|
104 |
+
|
105 |
+
def restore(self, np_image):
|
106 |
+
return gfpgan_fix_faces(np_image)
|
107 |
+
|
108 |
+
shared.face_restorers.append(FaceRestorerGFPGAN())
|
109 |
+
except Exception:
|
110 |
+
errors.report("Error setting up GFPGAN", exc_info=True)
|
modules/gitpython_hack.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import io
|
4 |
+
import subprocess
|
5 |
+
|
6 |
+
import git
|
7 |
+
|
8 |
+
|
9 |
+
class Git(git.Git):
|
10 |
+
"""
|
11 |
+
Git subclassed to never use persistent processes.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
|
15 |
+
raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
|
16 |
+
|
17 |
+
def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
|
18 |
+
ret = subprocess.check_output(
|
19 |
+
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
|
20 |
+
input=self._prepare_ref(ref),
|
21 |
+
cwd=self._working_dir,
|
22 |
+
timeout=2,
|
23 |
+
)
|
24 |
+
return self._parse_object_header(ret)
|
25 |
+
|
26 |
+
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
|
27 |
+
# Not really streaming, per se; this buffers the entire object in memory.
|
28 |
+
# Shouldn't be a problem for our use case, since we're only using this for
|
29 |
+
# object headers (commit objects).
|
30 |
+
ret = subprocess.check_output(
|
31 |
+
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
|
32 |
+
input=self._prepare_ref(ref),
|
33 |
+
cwd=self._working_dir,
|
34 |
+
timeout=30,
|
35 |
+
)
|
36 |
+
bio = io.BytesIO(ret)
|
37 |
+
hexsha, typename, size = self._parse_object_header(bio.readline())
|
38 |
+
return (hexsha, typename, size, self.CatFileContentStream(size, bio))
|
39 |
+
|
40 |
+
|
41 |
+
class Repo(git.Repo):
|
42 |
+
GitCommandWrapperType = Git
|
modules/hashes.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os.path
|
3 |
+
|
4 |
+
from modules import shared
|
5 |
+
import modules.cache
|
6 |
+
|
7 |
+
dump_cache = modules.cache.dump_cache
|
8 |
+
cache = modules.cache.cache
|
9 |
+
|
10 |
+
|
11 |
+
def calculate_sha256(filename):
|
12 |
+
hash_sha256 = hashlib.sha256()
|
13 |
+
blksize = 1024 * 1024
|
14 |
+
|
15 |
+
with open(filename, "rb") as f:
|
16 |
+
for chunk in iter(lambda: f.read(blksize), b""):
|
17 |
+
hash_sha256.update(chunk)
|
18 |
+
|
19 |
+
return hash_sha256.hexdigest()
|
20 |
+
|
21 |
+
|
22 |
+
def sha256_from_cache(filename, title, use_addnet_hash=False):
|
23 |
+
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
24 |
+
ondisk_mtime = os.path.getmtime(filename)
|
25 |
+
|
26 |
+
if title not in hashes:
|
27 |
+
return None
|
28 |
+
|
29 |
+
cached_sha256 = hashes[title].get("sha256", None)
|
30 |
+
cached_mtime = hashes[title].get("mtime", 0)
|
31 |
+
|
32 |
+
if ondisk_mtime > cached_mtime or cached_sha256 is None:
|
33 |
+
return None
|
34 |
+
|
35 |
+
return cached_sha256
|
36 |
+
|
37 |
+
|
38 |
+
def sha256(filename, title, use_addnet_hash=False):
|
39 |
+
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
40 |
+
|
41 |
+
sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
|
42 |
+
if sha256_value is not None:
|
43 |
+
return sha256_value
|
44 |
+
|
45 |
+
if shared.cmd_opts.no_hashing:
|
46 |
+
return None
|
47 |
+
|
48 |
+
print(f"Calculating sha256 for {filename}: ", end='')
|
49 |
+
if use_addnet_hash:
|
50 |
+
with open(filename, "rb") as file:
|
51 |
+
sha256_value = addnet_hash_safetensors(file)
|
52 |
+
else:
|
53 |
+
sha256_value = calculate_sha256(filename)
|
54 |
+
print(f"{sha256_value}")
|
55 |
+
|
56 |
+
hashes[title] = {
|
57 |
+
"mtime": os.path.getmtime(filename),
|
58 |
+
"sha256": sha256_value,
|
59 |
+
}
|
60 |
+
|
61 |
+
dump_cache()
|
62 |
+
|
63 |
+
return sha256_value
|
64 |
+
|
65 |
+
|
66 |
+
def addnet_hash_safetensors(b):
|
67 |
+
"""kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
|
68 |
+
hash_sha256 = hashlib.sha256()
|
69 |
+
blksize = 1024 * 1024
|
70 |
+
|
71 |
+
b.seek(0)
|
72 |
+
header = b.read(8)
|
73 |
+
n = int.from_bytes(header, "little")
|
74 |
+
|
75 |
+
offset = n + 8
|
76 |
+
b.seek(offset)
|
77 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
78 |
+
hash_sha256.update(chunk)
|
79 |
+
|
80 |
+
return hash_sha256.hexdigest()
|
81 |
+
|
modules/hypernetworks/__pycache__/hypernetwork.cpython-310.pyc
ADDED
Binary file (21.5 kB). View file
|
|
modules/hypernetworks/__pycache__/ui.cpython-310.pyc
ADDED
Binary file (1.69 kB). View file
|
|
modules/hypernetworks/hypernetwork.py
ADDED
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import glob
|
3 |
+
import html
|
4 |
+
import os
|
5 |
+
import inspect
|
6 |
+
from contextlib import closing
|
7 |
+
|
8 |
+
import modules.textual_inversion.dataset
|
9 |
+
import torch
|
10 |
+
import tqdm
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
from ldm.util import default
|
13 |
+
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
14 |
+
from modules.textual_inversion import textual_inversion, logging
|
15 |
+
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
16 |
+
from torch import einsum
|
17 |
+
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
18 |
+
|
19 |
+
from collections import deque
|
20 |
+
from statistics import stdev, mean
|
21 |
+
|
22 |
+
|
23 |
+
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
24 |
+
|
25 |
+
class HypernetworkModule(torch.nn.Module):
|
26 |
+
activation_dict = {
|
27 |
+
"linear": torch.nn.Identity,
|
28 |
+
"relu": torch.nn.ReLU,
|
29 |
+
"leakyrelu": torch.nn.LeakyReLU,
|
30 |
+
"elu": torch.nn.ELU,
|
31 |
+
"swish": torch.nn.Hardswish,
|
32 |
+
"tanh": torch.nn.Tanh,
|
33 |
+
"sigmoid": torch.nn.Sigmoid,
|
34 |
+
}
|
35 |
+
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
36 |
+
|
37 |
+
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
38 |
+
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.multiplier = 1.0
|
42 |
+
|
43 |
+
assert layer_structure is not None, "layer_structure must not be None"
|
44 |
+
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
45 |
+
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
46 |
+
|
47 |
+
linears = []
|
48 |
+
for i in range(len(layer_structure) - 1):
|
49 |
+
|
50 |
+
# Add a fully-connected layer
|
51 |
+
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
52 |
+
|
53 |
+
# Add an activation func except last layer
|
54 |
+
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
55 |
+
pass
|
56 |
+
elif activation_func in self.activation_dict:
|
57 |
+
linears.append(self.activation_dict[activation_func]())
|
58 |
+
else:
|
59 |
+
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
60 |
+
|
61 |
+
# Add layer normalization
|
62 |
+
if add_layer_norm:
|
63 |
+
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
64 |
+
|
65 |
+
# Everything should be now parsed into dropout structure, and applied here.
|
66 |
+
# Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
|
67 |
+
if dropout_structure is not None and dropout_structure[i+1] > 0:
|
68 |
+
assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
|
69 |
+
linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
|
70 |
+
# Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
|
71 |
+
|
72 |
+
self.linear = torch.nn.Sequential(*linears)
|
73 |
+
|
74 |
+
if state_dict is not None:
|
75 |
+
self.fix_old_state_dict(state_dict)
|
76 |
+
self.load_state_dict(state_dict)
|
77 |
+
else:
|
78 |
+
for layer in self.linear:
|
79 |
+
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
80 |
+
w, b = layer.weight.data, layer.bias.data
|
81 |
+
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
82 |
+
normal_(w, mean=0.0, std=0.01)
|
83 |
+
normal_(b, mean=0.0, std=0)
|
84 |
+
elif weight_init == 'XavierUniform':
|
85 |
+
xavier_uniform_(w)
|
86 |
+
zeros_(b)
|
87 |
+
elif weight_init == 'XavierNormal':
|
88 |
+
xavier_normal_(w)
|
89 |
+
zeros_(b)
|
90 |
+
elif weight_init == 'KaimingUniform':
|
91 |
+
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
92 |
+
zeros_(b)
|
93 |
+
elif weight_init == 'KaimingNormal':
|
94 |
+
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
95 |
+
zeros_(b)
|
96 |
+
else:
|
97 |
+
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
98 |
+
self.to(devices.device)
|
99 |
+
|
100 |
+
def fix_old_state_dict(self, state_dict):
|
101 |
+
changes = {
|
102 |
+
'linear1.bias': 'linear.0.bias',
|
103 |
+
'linear1.weight': 'linear.0.weight',
|
104 |
+
'linear2.bias': 'linear.1.bias',
|
105 |
+
'linear2.weight': 'linear.1.weight',
|
106 |
+
}
|
107 |
+
|
108 |
+
for fr, to in changes.items():
|
109 |
+
x = state_dict.get(fr, None)
|
110 |
+
if x is None:
|
111 |
+
continue
|
112 |
+
|
113 |
+
del state_dict[fr]
|
114 |
+
state_dict[to] = x
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
118 |
+
|
119 |
+
def trainables(self):
|
120 |
+
layer_structure = []
|
121 |
+
for layer in self.linear:
|
122 |
+
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
123 |
+
layer_structure += [layer.weight, layer.bias]
|
124 |
+
return layer_structure
|
125 |
+
|
126 |
+
|
127 |
+
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
128 |
+
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
129 |
+
if layer_structure is None:
|
130 |
+
layer_structure = [1, 2, 1]
|
131 |
+
if not use_dropout:
|
132 |
+
return [0] * len(layer_structure)
|
133 |
+
dropout_values = [0]
|
134 |
+
dropout_values.extend([0.3] * (len(layer_structure) - 3))
|
135 |
+
if last_layer_dropout:
|
136 |
+
dropout_values.append(0.3)
|
137 |
+
else:
|
138 |
+
dropout_values.append(0)
|
139 |
+
dropout_values.append(0)
|
140 |
+
return dropout_values
|
141 |
+
|
142 |
+
|
143 |
+
class Hypernetwork:
|
144 |
+
filename = None
|
145 |
+
name = None
|
146 |
+
|
147 |
+
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
|
148 |
+
self.filename = None
|
149 |
+
self.name = name
|
150 |
+
self.layers = {}
|
151 |
+
self.step = 0
|
152 |
+
self.sd_checkpoint = None
|
153 |
+
self.sd_checkpoint_name = None
|
154 |
+
self.layer_structure = layer_structure
|
155 |
+
self.activation_func = activation_func
|
156 |
+
self.weight_init = weight_init
|
157 |
+
self.add_layer_norm = add_layer_norm
|
158 |
+
self.use_dropout = use_dropout
|
159 |
+
self.activate_output = activate_output
|
160 |
+
self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
|
161 |
+
self.dropout_structure = kwargs.get('dropout_structure', None)
|
162 |
+
if self.dropout_structure is None:
|
163 |
+
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
164 |
+
self.optimizer_name = None
|
165 |
+
self.optimizer_state_dict = None
|
166 |
+
self.optional_info = None
|
167 |
+
|
168 |
+
for size in enable_sizes or []:
|
169 |
+
self.layers[size] = (
|
170 |
+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
171 |
+
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
172 |
+
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
173 |
+
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
174 |
+
)
|
175 |
+
self.eval()
|
176 |
+
|
177 |
+
def weights(self):
|
178 |
+
res = []
|
179 |
+
for layers in self.layers.values():
|
180 |
+
for layer in layers:
|
181 |
+
res += layer.parameters()
|
182 |
+
return res
|
183 |
+
|
184 |
+
def train(self, mode=True):
|
185 |
+
for layers in self.layers.values():
|
186 |
+
for layer in layers:
|
187 |
+
layer.train(mode=mode)
|
188 |
+
for param in layer.parameters():
|
189 |
+
param.requires_grad = mode
|
190 |
+
|
191 |
+
def to(self, device):
|
192 |
+
for layers in self.layers.values():
|
193 |
+
for layer in layers:
|
194 |
+
layer.to(device)
|
195 |
+
|
196 |
+
return self
|
197 |
+
|
198 |
+
def set_multiplier(self, multiplier):
|
199 |
+
for layers in self.layers.values():
|
200 |
+
for layer in layers:
|
201 |
+
layer.multiplier = multiplier
|
202 |
+
|
203 |
+
return self
|
204 |
+
|
205 |
+
def eval(self):
|
206 |
+
for layers in self.layers.values():
|
207 |
+
for layer in layers:
|
208 |
+
layer.eval()
|
209 |
+
for param in layer.parameters():
|
210 |
+
param.requires_grad = False
|
211 |
+
|
212 |
+
def save(self, filename):
|
213 |
+
state_dict = {}
|
214 |
+
optimizer_saved_dict = {}
|
215 |
+
|
216 |
+
for k, v in self.layers.items():
|
217 |
+
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
218 |
+
|
219 |
+
state_dict['step'] = self.step
|
220 |
+
state_dict['name'] = self.name
|
221 |
+
state_dict['layer_structure'] = self.layer_structure
|
222 |
+
state_dict['activation_func'] = self.activation_func
|
223 |
+
state_dict['is_layer_norm'] = self.add_layer_norm
|
224 |
+
state_dict['weight_initialization'] = self.weight_init
|
225 |
+
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
226 |
+
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
227 |
+
state_dict['activate_output'] = self.activate_output
|
228 |
+
state_dict['use_dropout'] = self.use_dropout
|
229 |
+
state_dict['dropout_structure'] = self.dropout_structure
|
230 |
+
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
|
231 |
+
state_dict['optional_info'] = self.optional_info if self.optional_info else None
|
232 |
+
|
233 |
+
if self.optimizer_name is not None:
|
234 |
+
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
235 |
+
|
236 |
+
torch.save(state_dict, filename)
|
237 |
+
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
238 |
+
optimizer_saved_dict['hash'] = self.shorthash()
|
239 |
+
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
240 |
+
torch.save(optimizer_saved_dict, filename + '.optim')
|
241 |
+
|
242 |
+
def load(self, filename):
|
243 |
+
self.filename = filename
|
244 |
+
if self.name is None:
|
245 |
+
self.name = os.path.splitext(os.path.basename(filename))[0]
|
246 |
+
|
247 |
+
state_dict = torch.load(filename, map_location='cpu')
|
248 |
+
|
249 |
+
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
250 |
+
self.optional_info = state_dict.get('optional_info', None)
|
251 |
+
self.activation_func = state_dict.get('activation_func', None)
|
252 |
+
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
253 |
+
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
254 |
+
self.dropout_structure = state_dict.get('dropout_structure', None)
|
255 |
+
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
256 |
+
self.activate_output = state_dict.get('activate_output', True)
|
257 |
+
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
258 |
+
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
259 |
+
if self.dropout_structure is None:
|
260 |
+
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
261 |
+
|
262 |
+
if shared.opts.print_hypernet_extra:
|
263 |
+
if self.optional_info is not None:
|
264 |
+
print(f" INFO:\n {self.optional_info}\n")
|
265 |
+
|
266 |
+
print(f" Layer structure: {self.layer_structure}")
|
267 |
+
print(f" Activation function: {self.activation_func}")
|
268 |
+
print(f" Weight initialization: {self.weight_init}")
|
269 |
+
print(f" Layer norm: {self.add_layer_norm}")
|
270 |
+
print(f" Dropout usage: {self.use_dropout}" )
|
271 |
+
print(f" Activate last layer: {self.activate_output}")
|
272 |
+
print(f" Dropout structure: {self.dropout_structure}")
|
273 |
+
|
274 |
+
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
|
275 |
+
|
276 |
+
if self.shorthash() == optimizer_saved_dict.get('hash', None):
|
277 |
+
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
278 |
+
else:
|
279 |
+
self.optimizer_state_dict = None
|
280 |
+
if self.optimizer_state_dict:
|
281 |
+
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
282 |
+
if shared.opts.print_hypernet_extra:
|
283 |
+
print("Loaded existing optimizer from checkpoint")
|
284 |
+
print(f"Optimizer name is {self.optimizer_name}")
|
285 |
+
else:
|
286 |
+
self.optimizer_name = "AdamW"
|
287 |
+
if shared.opts.print_hypernet_extra:
|
288 |
+
print("No saved optimizer exists in checkpoint")
|
289 |
+
|
290 |
+
for size, sd in state_dict.items():
|
291 |
+
if type(size) == int:
|
292 |
+
self.layers[size] = (
|
293 |
+
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
294 |
+
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
295 |
+
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
296 |
+
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
297 |
+
)
|
298 |
+
|
299 |
+
self.name = state_dict.get('name', self.name)
|
300 |
+
self.step = state_dict.get('step', 0)
|
301 |
+
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
302 |
+
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
303 |
+
self.eval()
|
304 |
+
|
305 |
+
def shorthash(self):
|
306 |
+
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
307 |
+
|
308 |
+
return sha256[0:10] if sha256 else None
|
309 |
+
|
310 |
+
|
311 |
+
def list_hypernetworks(path):
|
312 |
+
res = {}
|
313 |
+
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
|
314 |
+
name = os.path.splitext(os.path.basename(filename))[0]
|
315 |
+
# Prevent a hypothetical "None.pt" from being listed.
|
316 |
+
if name != "None":
|
317 |
+
res[name] = filename
|
318 |
+
return res
|
319 |
+
|
320 |
+
|
321 |
+
def load_hypernetwork(name):
|
322 |
+
path = shared.hypernetworks.get(name, None)
|
323 |
+
|
324 |
+
if path is None:
|
325 |
+
return None
|
326 |
+
|
327 |
+
try:
|
328 |
+
hypernetwork = Hypernetwork()
|
329 |
+
hypernetwork.load(path)
|
330 |
+
return hypernetwork
|
331 |
+
except Exception:
|
332 |
+
errors.report(f"Error loading hypernetwork {path}", exc_info=True)
|
333 |
+
return None
|
334 |
+
|
335 |
+
|
336 |
+
def load_hypernetworks(names, multipliers=None):
|
337 |
+
already_loaded = {}
|
338 |
+
|
339 |
+
for hypernetwork in shared.loaded_hypernetworks:
|
340 |
+
if hypernetwork.name in names:
|
341 |
+
already_loaded[hypernetwork.name] = hypernetwork
|
342 |
+
|
343 |
+
shared.loaded_hypernetworks.clear()
|
344 |
+
|
345 |
+
for i, name in enumerate(names):
|
346 |
+
hypernetwork = already_loaded.get(name, None)
|
347 |
+
if hypernetwork is None:
|
348 |
+
hypernetwork = load_hypernetwork(name)
|
349 |
+
|
350 |
+
if hypernetwork is None:
|
351 |
+
continue
|
352 |
+
|
353 |
+
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
|
354 |
+
shared.loaded_hypernetworks.append(hypernetwork)
|
355 |
+
|
356 |
+
|
357 |
+
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
358 |
+
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
359 |
+
|
360 |
+
if hypernetwork_layers is None:
|
361 |
+
return context_k, context_v
|
362 |
+
|
363 |
+
if layer is not None:
|
364 |
+
layer.hyper_k = hypernetwork_layers[0]
|
365 |
+
layer.hyper_v = hypernetwork_layers[1]
|
366 |
+
|
367 |
+
context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
|
368 |
+
context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
|
369 |
+
return context_k, context_v
|
370 |
+
|
371 |
+
|
372 |
+
def apply_hypernetworks(hypernetworks, context, layer=None):
|
373 |
+
context_k = context
|
374 |
+
context_v = context
|
375 |
+
for hypernetwork in hypernetworks:
|
376 |
+
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
|
377 |
+
|
378 |
+
return context_k, context_v
|
379 |
+
|
380 |
+
|
381 |
+
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
|
382 |
+
h = self.heads
|
383 |
+
|
384 |
+
q = self.to_q(x)
|
385 |
+
context = default(context, x)
|
386 |
+
|
387 |
+
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
|
388 |
+
k = self.to_k(context_k)
|
389 |
+
v = self.to_v(context_v)
|
390 |
+
|
391 |
+
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
392 |
+
|
393 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
394 |
+
|
395 |
+
if mask is not None:
|
396 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
397 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
398 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
399 |
+
sim.masked_fill_(~mask, max_neg_value)
|
400 |
+
|
401 |
+
# attention, what we cannot get enough of
|
402 |
+
attn = sim.softmax(dim=-1)
|
403 |
+
|
404 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
405 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
406 |
+
return self.to_out(out)
|
407 |
+
|
408 |
+
|
409 |
+
def stack_conds(conds):
|
410 |
+
if len(conds) == 1:
|
411 |
+
return torch.stack(conds)
|
412 |
+
|
413 |
+
# same as in reconstruct_multicond_batch
|
414 |
+
token_count = max([x.shape[0] for x in conds])
|
415 |
+
for i in range(len(conds)):
|
416 |
+
if conds[i].shape[0] != token_count:
|
417 |
+
last_vector = conds[i][-1:]
|
418 |
+
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
|
419 |
+
conds[i] = torch.vstack([conds[i], last_vector_repeated])
|
420 |
+
|
421 |
+
return torch.stack(conds)
|
422 |
+
|
423 |
+
|
424 |
+
def statistics(data):
|
425 |
+
if len(data) < 2:
|
426 |
+
std = 0
|
427 |
+
else:
|
428 |
+
std = stdev(data)
|
429 |
+
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
|
430 |
+
recent_data = data[-32:]
|
431 |
+
if len(recent_data) < 2:
|
432 |
+
std = 0
|
433 |
+
else:
|
434 |
+
std = stdev(recent_data)
|
435 |
+
recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
|
436 |
+
return total_information, recent_information
|
437 |
+
|
438 |
+
|
439 |
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
440 |
+
# Remove illegal characters from name.
|
441 |
+
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
442 |
+
assert name, "Name cannot be empty!"
|
443 |
+
|
444 |
+
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
445 |
+
if not overwrite_old:
|
446 |
+
assert not os.path.exists(fn), f"file {fn} already exists"
|
447 |
+
|
448 |
+
if type(layer_structure) == str:
|
449 |
+
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
450 |
+
|
451 |
+
if use_dropout and dropout_structure and type(dropout_structure) == str:
|
452 |
+
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
453 |
+
else:
|
454 |
+
dropout_structure = [0] * len(layer_structure)
|
455 |
+
|
456 |
+
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
457 |
+
name=name,
|
458 |
+
enable_sizes=[int(x) for x in enable_sizes],
|
459 |
+
layer_structure=layer_structure,
|
460 |
+
activation_func=activation_func,
|
461 |
+
weight_init=weight_init,
|
462 |
+
add_layer_norm=add_layer_norm,
|
463 |
+
use_dropout=use_dropout,
|
464 |
+
dropout_structure=dropout_structure
|
465 |
+
)
|
466 |
+
hypernet.save(fn)
|
467 |
+
|
468 |
+
shared.reload_hypernetworks()
|
469 |
+
|
470 |
+
|
471 |
+
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
472 |
+
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
473 |
+
from modules import images
|
474 |
+
|
475 |
+
save_hypernetwork_every = save_hypernetwork_every or 0
|
476 |
+
create_image_every = create_image_every or 0
|
477 |
+
template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
|
478 |
+
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
479 |
+
template_file = template_file.path
|
480 |
+
|
481 |
+
path = shared.hypernetworks.get(hypernetwork_name, None)
|
482 |
+
hypernetwork = Hypernetwork()
|
483 |
+
hypernetwork.load(path)
|
484 |
+
shared.loaded_hypernetworks = [hypernetwork]
|
485 |
+
|
486 |
+
shared.state.job = "train-hypernetwork"
|
487 |
+
shared.state.textinfo = "Initializing hypernetwork training..."
|
488 |
+
shared.state.job_count = steps
|
489 |
+
|
490 |
+
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
491 |
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
492 |
+
|
493 |
+
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
494 |
+
unload = shared.opts.unload_models_when_training
|
495 |
+
|
496 |
+
if save_hypernetwork_every > 0:
|
497 |
+
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
498 |
+
os.makedirs(hypernetwork_dir, exist_ok=True)
|
499 |
+
else:
|
500 |
+
hypernetwork_dir = None
|
501 |
+
|
502 |
+
if create_image_every > 0:
|
503 |
+
images_dir = os.path.join(log_directory, "images")
|
504 |
+
os.makedirs(images_dir, exist_ok=True)
|
505 |
+
else:
|
506 |
+
images_dir = None
|
507 |
+
|
508 |
+
checkpoint = sd_models.select_checkpoint()
|
509 |
+
|
510 |
+
initial_step = hypernetwork.step or 0
|
511 |
+
if initial_step >= steps:
|
512 |
+
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
513 |
+
return hypernetwork, filename
|
514 |
+
|
515 |
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
516 |
+
|
517 |
+
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
518 |
+
if clip_grad:
|
519 |
+
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
520 |
+
|
521 |
+
if shared.opts.training_enable_tensorboard:
|
522 |
+
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
523 |
+
|
524 |
+
# dataset loading may take a while, so input validations and early returns should be done before this
|
525 |
+
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
526 |
+
|
527 |
+
pin_memory = shared.opts.pin_memory
|
528 |
+
|
529 |
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
|
530 |
+
|
531 |
+
if shared.opts.save_training_settings_to_txt:
|
532 |
+
saved_params = dict(
|
533 |
+
model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
|
534 |
+
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
535 |
+
)
|
536 |
+
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
537 |
+
|
538 |
+
latent_sampling_method = ds.latent_sampling_method
|
539 |
+
|
540 |
+
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
541 |
+
|
542 |
+
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
543 |
+
|
544 |
+
if unload:
|
545 |
+
shared.parallel_processing_allowed = False
|
546 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
547 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
548 |
+
|
549 |
+
weights = hypernetwork.weights()
|
550 |
+
hypernetwork.train()
|
551 |
+
|
552 |
+
# Here we use optimizer from saved HN, or we can specify as UI option.
|
553 |
+
if hypernetwork.optimizer_name in optimizer_dict:
|
554 |
+
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
555 |
+
optimizer_name = hypernetwork.optimizer_name
|
556 |
+
else:
|
557 |
+
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
558 |
+
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
559 |
+
optimizer_name = 'AdamW'
|
560 |
+
|
561 |
+
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
562 |
+
try:
|
563 |
+
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
564 |
+
except RuntimeError as e:
|
565 |
+
print("Cannot resume from saved optimizer!")
|
566 |
+
print(e)
|
567 |
+
|
568 |
+
scaler = torch.cuda.amp.GradScaler()
|
569 |
+
|
570 |
+
batch_size = ds.batch_size
|
571 |
+
gradient_step = ds.gradient_step
|
572 |
+
# n steps = batch_size * gradient_step * n image processed
|
573 |
+
steps_per_epoch = len(ds) // batch_size // gradient_step
|
574 |
+
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
575 |
+
loss_step = 0
|
576 |
+
_loss_step = 0 #internal
|
577 |
+
# size = len(ds.indexes)
|
578 |
+
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
579 |
+
loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
|
580 |
+
# losses = torch.zeros((size,))
|
581 |
+
# previous_mean_losses = [0]
|
582 |
+
# previous_mean_loss = 0
|
583 |
+
# print("Mean loss of {} elements".format(size))
|
584 |
+
|
585 |
+
steps_without_grad = 0
|
586 |
+
|
587 |
+
last_saved_file = "<none>"
|
588 |
+
last_saved_image = "<none>"
|
589 |
+
forced_filename = "<none>"
|
590 |
+
|
591 |
+
pbar = tqdm.tqdm(total=steps - initial_step)
|
592 |
+
try:
|
593 |
+
sd_hijack_checkpoint.add()
|
594 |
+
|
595 |
+
for _ in range((steps-initial_step) * gradient_step):
|
596 |
+
if scheduler.finished:
|
597 |
+
break
|
598 |
+
if shared.state.interrupted:
|
599 |
+
break
|
600 |
+
for j, batch in enumerate(dl):
|
601 |
+
# works as a drop_last=True for gradient accumulation
|
602 |
+
if j == max_steps_per_epoch:
|
603 |
+
break
|
604 |
+
scheduler.apply(optimizer, hypernetwork.step)
|
605 |
+
if scheduler.finished:
|
606 |
+
break
|
607 |
+
if shared.state.interrupted:
|
608 |
+
break
|
609 |
+
|
610 |
+
if clip_grad:
|
611 |
+
clip_grad_sched.step(hypernetwork.step)
|
612 |
+
|
613 |
+
with devices.autocast():
|
614 |
+
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
615 |
+
if use_weight:
|
616 |
+
w = batch.weight.to(devices.device, non_blocking=pin_memory)
|
617 |
+
if tag_drop_out != 0 or shuffle_tags:
|
618 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
619 |
+
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
620 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
621 |
+
else:
|
622 |
+
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
623 |
+
if use_weight:
|
624 |
+
loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
|
625 |
+
del w
|
626 |
+
else:
|
627 |
+
loss = shared.sd_model.forward(x, c)[0] / gradient_step
|
628 |
+
del x
|
629 |
+
del c
|
630 |
+
|
631 |
+
_loss_step += loss.item()
|
632 |
+
scaler.scale(loss).backward()
|
633 |
+
|
634 |
+
# go back until we reach gradient accumulation steps
|
635 |
+
if (j + 1) % gradient_step != 0:
|
636 |
+
continue
|
637 |
+
loss_logging.append(_loss_step)
|
638 |
+
if clip_grad:
|
639 |
+
clip_grad(weights, clip_grad_sched.learn_rate)
|
640 |
+
|
641 |
+
scaler.step(optimizer)
|
642 |
+
scaler.update()
|
643 |
+
hypernetwork.step += 1
|
644 |
+
pbar.update()
|
645 |
+
optimizer.zero_grad(set_to_none=True)
|
646 |
+
loss_step = _loss_step
|
647 |
+
_loss_step = 0
|
648 |
+
|
649 |
+
steps_done = hypernetwork.step + 1
|
650 |
+
|
651 |
+
epoch_num = hypernetwork.step // steps_per_epoch
|
652 |
+
epoch_step = hypernetwork.step % steps_per_epoch
|
653 |
+
|
654 |
+
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
655 |
+
pbar.set_description(description)
|
656 |
+
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
657 |
+
# Before saving, change name to match current checkpoint.
|
658 |
+
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
659 |
+
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
660 |
+
hypernetwork.optimizer_name = optimizer_name
|
661 |
+
if shared.opts.save_optimizer_state:
|
662 |
+
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
663 |
+
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
664 |
+
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
665 |
+
|
666 |
+
|
667 |
+
|
668 |
+
if shared.opts.training_enable_tensorboard:
|
669 |
+
epoch_num = hypernetwork.step // len(ds)
|
670 |
+
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
671 |
+
mean_loss = sum(loss_logging) / len(loss_logging)
|
672 |
+
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
673 |
+
|
674 |
+
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
675 |
+
"loss": f"{loss_step:.7f}",
|
676 |
+
"learn_rate": scheduler.learn_rate
|
677 |
+
})
|
678 |
+
|
679 |
+
if images_dir is not None and steps_done % create_image_every == 0:
|
680 |
+
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
681 |
+
last_saved_image = os.path.join(images_dir, forced_filename)
|
682 |
+
hypernetwork.eval()
|
683 |
+
rng_state = torch.get_rng_state()
|
684 |
+
cuda_rng_state = None
|
685 |
+
if torch.cuda.is_available():
|
686 |
+
cuda_rng_state = torch.cuda.get_rng_state_all()
|
687 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
688 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
689 |
+
|
690 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
691 |
+
sd_model=shared.sd_model,
|
692 |
+
do_not_save_grid=True,
|
693 |
+
do_not_save_samples=True,
|
694 |
+
)
|
695 |
+
|
696 |
+
p.disable_extra_networks = True
|
697 |
+
|
698 |
+
if preview_from_txt2img:
|
699 |
+
p.prompt = preview_prompt
|
700 |
+
p.negative_prompt = preview_negative_prompt
|
701 |
+
p.steps = preview_steps
|
702 |
+
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
703 |
+
p.cfg_scale = preview_cfg_scale
|
704 |
+
p.seed = preview_seed
|
705 |
+
p.width = preview_width
|
706 |
+
p.height = preview_height
|
707 |
+
else:
|
708 |
+
p.prompt = batch.cond_text[0]
|
709 |
+
p.steps = 20
|
710 |
+
p.width = training_width
|
711 |
+
p.height = training_height
|
712 |
+
|
713 |
+
preview_text = p.prompt
|
714 |
+
|
715 |
+
with closing(p):
|
716 |
+
processed = processing.process_images(p)
|
717 |
+
image = processed.images[0] if len(processed.images) > 0 else None
|
718 |
+
|
719 |
+
if unload:
|
720 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
721 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
722 |
+
torch.set_rng_state(rng_state)
|
723 |
+
if torch.cuda.is_available():
|
724 |
+
torch.cuda.set_rng_state_all(cuda_rng_state)
|
725 |
+
hypernetwork.train()
|
726 |
+
if image is not None:
|
727 |
+
shared.state.assign_current_image(image)
|
728 |
+
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
729 |
+
textual_inversion.tensorboard_add_image(tensorboard_writer,
|
730 |
+
f"Validation at epoch {epoch_num}", image,
|
731 |
+
hypernetwork.step)
|
732 |
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
733 |
+
last_saved_image += f", prompt: {preview_text}"
|
734 |
+
|
735 |
+
shared.state.job_no = hypernetwork.step
|
736 |
+
|
737 |
+
shared.state.textinfo = f"""
|
738 |
+
<p>
|
739 |
+
Loss: {loss_step:.7f}<br/>
|
740 |
+
Step: {steps_done}<br/>
|
741 |
+
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
742 |
+
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
743 |
+
Last saved image: {html.escape(last_saved_image)}<br/>
|
744 |
+
</p>
|
745 |
+
"""
|
746 |
+
except Exception:
|
747 |
+
errors.report("Exception in training hypernetwork", exc_info=True)
|
748 |
+
finally:
|
749 |
+
pbar.leave = False
|
750 |
+
pbar.close()
|
751 |
+
hypernetwork.eval()
|
752 |
+
sd_hijack_checkpoint.remove()
|
753 |
+
|
754 |
+
|
755 |
+
|
756 |
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
757 |
+
hypernetwork.optimizer_name = optimizer_name
|
758 |
+
if shared.opts.save_optimizer_state:
|
759 |
+
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
760 |
+
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
761 |
+
|
762 |
+
del optimizer
|
763 |
+
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
764 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
765 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
766 |
+
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
767 |
+
|
768 |
+
return hypernetwork, filename
|
769 |
+
|
770 |
+
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
771 |
+
old_hypernetwork_name = hypernetwork.name
|
772 |
+
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
773 |
+
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
774 |
+
try:
|
775 |
+
hypernetwork.sd_checkpoint = checkpoint.shorthash
|
776 |
+
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
777 |
+
hypernetwork.name = hypernetwork_name
|
778 |
+
hypernetwork.save(filename)
|
779 |
+
except:
|
780 |
+
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
781 |
+
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
782 |
+
hypernetwork.name = old_hypernetwork_name
|
783 |
+
raise
|
modules/hypernetworks/ui.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import modules.hypernetworks.hypernetwork
|
5 |
+
from modules import devices, sd_hijack, shared
|
6 |
+
|
7 |
+
not_available = ["hardswish", "multiheadattention"]
|
8 |
+
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
9 |
+
|
10 |
+
|
11 |
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
12 |
+
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
13 |
+
|
14 |
+
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
15 |
+
|
16 |
+
|
17 |
+
def train_hypernetwork(*args):
|
18 |
+
shared.loaded_hypernetworks = []
|
19 |
+
|
20 |
+
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
21 |
+
|
22 |
+
try:
|
23 |
+
sd_hijack.undo_optimizations()
|
24 |
+
|
25 |
+
hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
|
26 |
+
|
27 |
+
res = f"""
|
28 |
+
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
|
29 |
+
Hypernetwork saved to {html.escape(filename)}
|
30 |
+
"""
|
31 |
+
return res, ""
|
32 |
+
except Exception:
|
33 |
+
raise
|
34 |
+
finally:
|
35 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
36 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
37 |
+
sd_hijack.apply_optimizations()
|
38 |
+
|
modules/images.py
ADDED
@@ -0,0 +1,758 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
|
5 |
+
import pytz
|
6 |
+
import io
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from collections import namedtuple
|
10 |
+
import re
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import piexif
|
14 |
+
import piexif.helper
|
15 |
+
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
|
16 |
+
import string
|
17 |
+
import json
|
18 |
+
import hashlib
|
19 |
+
|
20 |
+
from modules import sd_samplers, shared, script_callbacks, errors
|
21 |
+
from modules.paths_internal import roboto_ttf_file
|
22 |
+
from modules.shared import opts
|
23 |
+
|
24 |
+
import modules.sd_vae as sd_vae
|
25 |
+
|
26 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
27 |
+
|
28 |
+
|
29 |
+
def get_font(fontsize: int):
|
30 |
+
try:
|
31 |
+
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
|
32 |
+
except Exception:
|
33 |
+
return ImageFont.truetype(roboto_ttf_file, fontsize)
|
34 |
+
|
35 |
+
|
36 |
+
def image_grid(imgs, batch_size=1, rows=None):
|
37 |
+
if rows is None:
|
38 |
+
if opts.n_rows > 0:
|
39 |
+
rows = opts.n_rows
|
40 |
+
elif opts.n_rows == 0:
|
41 |
+
rows = batch_size
|
42 |
+
elif opts.grid_prevent_empty_spots:
|
43 |
+
rows = math.floor(math.sqrt(len(imgs)))
|
44 |
+
while len(imgs) % rows != 0:
|
45 |
+
rows -= 1
|
46 |
+
else:
|
47 |
+
rows = math.sqrt(len(imgs))
|
48 |
+
rows = round(rows)
|
49 |
+
if rows > len(imgs):
|
50 |
+
rows = len(imgs)
|
51 |
+
|
52 |
+
cols = math.ceil(len(imgs) / rows)
|
53 |
+
|
54 |
+
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
55 |
+
script_callbacks.image_grid_callback(params)
|
56 |
+
|
57 |
+
w, h = imgs[0].size
|
58 |
+
grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
|
59 |
+
|
60 |
+
for i, img in enumerate(params.imgs):
|
61 |
+
grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
|
62 |
+
|
63 |
+
return grid
|
64 |
+
|
65 |
+
|
66 |
+
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
67 |
+
|
68 |
+
|
69 |
+
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
70 |
+
w = image.width
|
71 |
+
h = image.height
|
72 |
+
|
73 |
+
non_overlap_width = tile_w - overlap
|
74 |
+
non_overlap_height = tile_h - overlap
|
75 |
+
|
76 |
+
cols = math.ceil((w - overlap) / non_overlap_width)
|
77 |
+
rows = math.ceil((h - overlap) / non_overlap_height)
|
78 |
+
|
79 |
+
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
80 |
+
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
81 |
+
|
82 |
+
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
83 |
+
for row in range(rows):
|
84 |
+
row_images = []
|
85 |
+
|
86 |
+
y = int(row * dy)
|
87 |
+
|
88 |
+
if y + tile_h >= h:
|
89 |
+
y = h - tile_h
|
90 |
+
|
91 |
+
for col in range(cols):
|
92 |
+
x = int(col * dx)
|
93 |
+
|
94 |
+
if x + tile_w >= w:
|
95 |
+
x = w - tile_w
|
96 |
+
|
97 |
+
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
98 |
+
|
99 |
+
row_images.append([x, tile_w, tile])
|
100 |
+
|
101 |
+
grid.tiles.append([y, tile_h, row_images])
|
102 |
+
|
103 |
+
return grid
|
104 |
+
|
105 |
+
|
106 |
+
def combine_grid(grid):
|
107 |
+
def make_mask_image(r):
|
108 |
+
r = r * 255 / grid.overlap
|
109 |
+
r = r.astype(np.uint8)
|
110 |
+
return Image.fromarray(r, 'L')
|
111 |
+
|
112 |
+
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
113 |
+
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
114 |
+
|
115 |
+
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
116 |
+
for y, h, row in grid.tiles:
|
117 |
+
combined_row = Image.new("RGB", (grid.image_w, h))
|
118 |
+
for x, w, tile in row:
|
119 |
+
if x == 0:
|
120 |
+
combined_row.paste(tile, (0, 0))
|
121 |
+
continue
|
122 |
+
|
123 |
+
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
124 |
+
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
125 |
+
|
126 |
+
if y == 0:
|
127 |
+
combined_image.paste(combined_row, (0, 0))
|
128 |
+
continue
|
129 |
+
|
130 |
+
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
|
131 |
+
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
|
132 |
+
|
133 |
+
return combined_image
|
134 |
+
|
135 |
+
|
136 |
+
class GridAnnotation:
|
137 |
+
def __init__(self, text='', is_active=True):
|
138 |
+
self.text = text
|
139 |
+
self.is_active = is_active
|
140 |
+
self.size = None
|
141 |
+
|
142 |
+
|
143 |
+
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
144 |
+
|
145 |
+
color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
|
146 |
+
color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
|
147 |
+
color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
|
148 |
+
|
149 |
+
def wrap(drawing, text, font, line_length):
|
150 |
+
lines = ['']
|
151 |
+
for word in text.split():
|
152 |
+
line = f'{lines[-1]} {word}'.strip()
|
153 |
+
if drawing.textlength(line, font=font) <= line_length:
|
154 |
+
lines[-1] = line
|
155 |
+
else:
|
156 |
+
lines.append(word)
|
157 |
+
return lines
|
158 |
+
|
159 |
+
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
160 |
+
for line in lines:
|
161 |
+
fnt = initial_fnt
|
162 |
+
fontsize = initial_fontsize
|
163 |
+
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
164 |
+
fontsize -= 1
|
165 |
+
fnt = get_font(fontsize)
|
166 |
+
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
167 |
+
|
168 |
+
if not line.is_active:
|
169 |
+
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
170 |
+
|
171 |
+
draw_y += line.size[1] + line_spacing
|
172 |
+
|
173 |
+
fontsize = (width + height) // 25
|
174 |
+
line_spacing = fontsize // 2
|
175 |
+
|
176 |
+
fnt = get_font(fontsize)
|
177 |
+
|
178 |
+
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
179 |
+
|
180 |
+
cols = im.width // width
|
181 |
+
rows = im.height // height
|
182 |
+
|
183 |
+
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
184 |
+
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
185 |
+
|
186 |
+
calc_img = Image.new("RGB", (1, 1), color_background)
|
187 |
+
calc_d = ImageDraw.Draw(calc_img)
|
188 |
+
|
189 |
+
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
190 |
+
items = [] + texts
|
191 |
+
texts.clear()
|
192 |
+
|
193 |
+
for line in items:
|
194 |
+
wrapped = wrap(calc_d, line.text, fnt, allowed_width)
|
195 |
+
texts += [GridAnnotation(x, line.is_active) for x in wrapped]
|
196 |
+
|
197 |
+
for line in texts:
|
198 |
+
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
|
199 |
+
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
200 |
+
line.allowed_width = allowed_width
|
201 |
+
|
202 |
+
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
203 |
+
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
204 |
+
|
205 |
+
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
206 |
+
|
207 |
+
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
|
208 |
+
|
209 |
+
for row in range(rows):
|
210 |
+
for col in range(cols):
|
211 |
+
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
212 |
+
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
213 |
+
|
214 |
+
d = ImageDraw.Draw(result)
|
215 |
+
|
216 |
+
for col in range(cols):
|
217 |
+
x = pad_left + (width + margin) * col + width / 2
|
218 |
+
y = pad_top / 2 - hor_text_heights[col] / 2
|
219 |
+
|
220 |
+
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
221 |
+
|
222 |
+
for row in range(rows):
|
223 |
+
x = pad_left / 2
|
224 |
+
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
225 |
+
|
226 |
+
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
227 |
+
|
228 |
+
return result
|
229 |
+
|
230 |
+
|
231 |
+
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
232 |
+
prompts = all_prompts[1:]
|
233 |
+
boundary = math.ceil(len(prompts) / 2)
|
234 |
+
|
235 |
+
prompts_horiz = prompts[:boundary]
|
236 |
+
prompts_vert = prompts[boundary:]
|
237 |
+
|
238 |
+
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
239 |
+
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
240 |
+
|
241 |
+
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
242 |
+
|
243 |
+
|
244 |
+
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
245 |
+
"""
|
246 |
+
Resizes an image with the specified resize_mode, width, and height.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
resize_mode: The mode to use when resizing the image.
|
250 |
+
0: Resize the image to the specified width and height.
|
251 |
+
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
252 |
+
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
253 |
+
im: The image to resize.
|
254 |
+
width: The width to resize the image to.
|
255 |
+
height: The height to resize the image to.
|
256 |
+
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
|
257 |
+
"""
|
258 |
+
|
259 |
+
upscaler_name = upscaler_name or opts.upscaler_for_img2img
|
260 |
+
|
261 |
+
def resize(im, w, h):
|
262 |
+
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
|
263 |
+
return im.resize((w, h), resample=LANCZOS)
|
264 |
+
|
265 |
+
scale = max(w / im.width, h / im.height)
|
266 |
+
|
267 |
+
if scale > 1.0:
|
268 |
+
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
|
269 |
+
if len(upscalers) == 0:
|
270 |
+
upscaler = shared.sd_upscalers[0]
|
271 |
+
print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
|
272 |
+
else:
|
273 |
+
upscaler = upscalers[0]
|
274 |
+
|
275 |
+
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
276 |
+
|
277 |
+
if im.width != w or im.height != h:
|
278 |
+
im = im.resize((w, h), resample=LANCZOS)
|
279 |
+
|
280 |
+
return im
|
281 |
+
|
282 |
+
if resize_mode == 0:
|
283 |
+
res = resize(im, width, height)
|
284 |
+
|
285 |
+
elif resize_mode == 1:
|
286 |
+
ratio = width / height
|
287 |
+
src_ratio = im.width / im.height
|
288 |
+
|
289 |
+
src_w = width if ratio > src_ratio else im.width * height // im.height
|
290 |
+
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
291 |
+
|
292 |
+
resized = resize(im, src_w, src_h)
|
293 |
+
res = Image.new("RGB", (width, height))
|
294 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
295 |
+
|
296 |
+
else:
|
297 |
+
ratio = width / height
|
298 |
+
src_ratio = im.width / im.height
|
299 |
+
|
300 |
+
src_w = width if ratio < src_ratio else im.width * height // im.height
|
301 |
+
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
302 |
+
|
303 |
+
resized = resize(im, src_w, src_h)
|
304 |
+
res = Image.new("RGB", (width, height))
|
305 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
306 |
+
|
307 |
+
if ratio < src_ratio:
|
308 |
+
fill_height = height // 2 - src_h // 2
|
309 |
+
if fill_height > 0:
|
310 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
311 |
+
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
312 |
+
elif ratio > src_ratio:
|
313 |
+
fill_width = width // 2 - src_w // 2
|
314 |
+
if fill_width > 0:
|
315 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
316 |
+
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
317 |
+
|
318 |
+
return res
|
319 |
+
|
320 |
+
|
321 |
+
invalid_filename_chars = '<>:"/\\|?*\n'
|
322 |
+
invalid_filename_prefix = ' '
|
323 |
+
invalid_filename_postfix = ' .'
|
324 |
+
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
325 |
+
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
326 |
+
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
327 |
+
max_filename_part_length = 128
|
328 |
+
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
329 |
+
|
330 |
+
|
331 |
+
def sanitize_filename_part(text, replace_spaces=True):
|
332 |
+
if text is None:
|
333 |
+
return None
|
334 |
+
|
335 |
+
if replace_spaces:
|
336 |
+
text = text.replace(' ', '_')
|
337 |
+
|
338 |
+
text = text.translate({ord(x): '_' for x in invalid_filename_chars})
|
339 |
+
text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
|
340 |
+
text = text.rstrip(invalid_filename_postfix)
|
341 |
+
return text
|
342 |
+
|
343 |
+
|
344 |
+
class FilenameGenerator:
|
345 |
+
def get_vae_filename(self): #get the name of the VAE file.
|
346 |
+
if sd_vae.loaded_vae_file is None:
|
347 |
+
return "NoneType"
|
348 |
+
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
349 |
+
split_file_name = file_name.split('.')
|
350 |
+
if len(split_file_name) > 1 and split_file_name[0] == '':
|
351 |
+
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
352 |
+
else:
|
353 |
+
return split_file_name[0]
|
354 |
+
|
355 |
+
replacements = {
|
356 |
+
'seed': lambda self: self.seed if self.seed is not None else '',
|
357 |
+
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
358 |
+
'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
|
359 |
+
'steps': lambda self: self.p and self.p.steps,
|
360 |
+
'cfg': lambda self: self.p and self.p.cfg_scale,
|
361 |
+
'width': lambda self: self.image.width,
|
362 |
+
'height': lambda self: self.image.height,
|
363 |
+
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
364 |
+
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
|
365 |
+
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
366 |
+
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
|
367 |
+
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
368 |
+
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
369 |
+
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
370 |
+
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
371 |
+
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
372 |
+
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
373 |
+
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
374 |
+
'prompt_words': lambda self: self.prompt_words(),
|
375 |
+
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
|
376 |
+
'batch_size': lambda self: self.p.batch_size,
|
377 |
+
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
378 |
+
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
379 |
+
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
380 |
+
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
381 |
+
'user': lambda self: self.p.user,
|
382 |
+
'vae_filename': lambda self: self.get_vae_filename(),
|
383 |
+
'none': lambda self: '', # Overrides the default so you can get just the sequence number
|
384 |
+
}
|
385 |
+
default_time_format = '%Y%m%d%H%M%S'
|
386 |
+
|
387 |
+
def __init__(self, p, seed, prompt, image, zip=False):
|
388 |
+
self.p = p
|
389 |
+
self.seed = seed
|
390 |
+
self.prompt = prompt
|
391 |
+
self.image = image
|
392 |
+
self.zip = zip
|
393 |
+
|
394 |
+
def hasprompt(self, *args):
|
395 |
+
lower = self.prompt.lower()
|
396 |
+
if self.p is None or self.prompt is None:
|
397 |
+
return None
|
398 |
+
outres = ""
|
399 |
+
for arg in args:
|
400 |
+
if arg != "":
|
401 |
+
division = arg.split("|")
|
402 |
+
expected = division[0].lower()
|
403 |
+
default = division[1] if len(division) > 1 else ""
|
404 |
+
if lower.find(expected) >= 0:
|
405 |
+
outres = f'{outres}{expected}'
|
406 |
+
else:
|
407 |
+
outres = outres if default == "" else f'{outres}{default}'
|
408 |
+
return sanitize_filename_part(outres)
|
409 |
+
|
410 |
+
def prompt_no_style(self):
|
411 |
+
if self.p is None or self.prompt is None:
|
412 |
+
return None
|
413 |
+
|
414 |
+
prompt_no_style = self.prompt
|
415 |
+
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
|
416 |
+
if style:
|
417 |
+
for part in style.split("{prompt}"):
|
418 |
+
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
419 |
+
|
420 |
+
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
421 |
+
|
422 |
+
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
|
423 |
+
|
424 |
+
def prompt_words(self):
|
425 |
+
words = [x for x in re_nonletters.split(self.prompt or "") if x]
|
426 |
+
if len(words) == 0:
|
427 |
+
words = ["empty"]
|
428 |
+
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
|
429 |
+
|
430 |
+
def datetime(self, *args):
|
431 |
+
time_datetime = datetime.datetime.now()
|
432 |
+
|
433 |
+
time_format = args[0] if (args and args[0] != "") else self.default_time_format
|
434 |
+
try:
|
435 |
+
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
436 |
+
except pytz.exceptions.UnknownTimeZoneError:
|
437 |
+
time_zone = None
|
438 |
+
|
439 |
+
time_zone_time = time_datetime.astimezone(time_zone)
|
440 |
+
try:
|
441 |
+
formatted_time = time_zone_time.strftime(time_format)
|
442 |
+
except (ValueError, TypeError):
|
443 |
+
formatted_time = time_zone_time.strftime(self.default_time_format)
|
444 |
+
|
445 |
+
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
446 |
+
|
447 |
+
def apply(self, x):
|
448 |
+
res = ''
|
449 |
+
|
450 |
+
for m in re_pattern.finditer(x):
|
451 |
+
text, pattern = m.groups()
|
452 |
+
|
453 |
+
if pattern is None:
|
454 |
+
res += text
|
455 |
+
continue
|
456 |
+
|
457 |
+
pattern_args = []
|
458 |
+
while True:
|
459 |
+
m = re_pattern_arg.match(pattern)
|
460 |
+
if m is None:
|
461 |
+
break
|
462 |
+
|
463 |
+
pattern, arg = m.groups()
|
464 |
+
pattern_args.insert(0, arg)
|
465 |
+
|
466 |
+
fun = self.replacements.get(pattern.lower())
|
467 |
+
if fun is not None:
|
468 |
+
try:
|
469 |
+
replacement = fun(self, *pattern_args)
|
470 |
+
except Exception:
|
471 |
+
replacement = None
|
472 |
+
errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
|
473 |
+
|
474 |
+
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
475 |
+
continue
|
476 |
+
elif replacement is not None:
|
477 |
+
res += text + str(replacement)
|
478 |
+
continue
|
479 |
+
|
480 |
+
res += f'{text}[{pattern}]'
|
481 |
+
|
482 |
+
return res
|
483 |
+
|
484 |
+
|
485 |
+
def get_next_sequence_number(path, basename):
|
486 |
+
"""
|
487 |
+
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
488 |
+
|
489 |
+
The sequence starts at 0.
|
490 |
+
"""
|
491 |
+
result = -1
|
492 |
+
if basename != '':
|
493 |
+
basename = f"{basename}-"
|
494 |
+
|
495 |
+
prefix_length = len(basename)
|
496 |
+
for p in os.listdir(path):
|
497 |
+
if p.startswith(basename):
|
498 |
+
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
499 |
+
try:
|
500 |
+
result = max(int(parts[0]), result)
|
501 |
+
except ValueError:
|
502 |
+
pass
|
503 |
+
|
504 |
+
return result + 1
|
505 |
+
|
506 |
+
|
507 |
+
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
|
508 |
+
"""
|
509 |
+
Saves image to filename, including geninfo as text information for generation info.
|
510 |
+
For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
|
511 |
+
For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
|
512 |
+
"""
|
513 |
+
|
514 |
+
if extension is None:
|
515 |
+
extension = os.path.splitext(filename)[1]
|
516 |
+
|
517 |
+
image_format = Image.registered_extensions()[extension]
|
518 |
+
|
519 |
+
if extension.lower() == '.png':
|
520 |
+
existing_pnginfo = existing_pnginfo or {}
|
521 |
+
if opts.enable_pnginfo:
|
522 |
+
existing_pnginfo[pnginfo_section_name] = geninfo
|
523 |
+
|
524 |
+
if opts.enable_pnginfo:
|
525 |
+
pnginfo_data = PngImagePlugin.PngInfo()
|
526 |
+
for k, v in (existing_pnginfo or {}).items():
|
527 |
+
pnginfo_data.add_text(k, str(v))
|
528 |
+
else:
|
529 |
+
pnginfo_data = None
|
530 |
+
|
531 |
+
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
532 |
+
|
533 |
+
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
534 |
+
if image.mode == 'RGBA':
|
535 |
+
image = image.convert("RGB")
|
536 |
+
elif image.mode == 'I;16':
|
537 |
+
image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
538 |
+
|
539 |
+
image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
540 |
+
|
541 |
+
if opts.enable_pnginfo and geninfo is not None:
|
542 |
+
exif_bytes = piexif.dump({
|
543 |
+
"Exif": {
|
544 |
+
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
545 |
+
},
|
546 |
+
})
|
547 |
+
|
548 |
+
piexif.insert(exif_bytes, filename)
|
549 |
+
else:
|
550 |
+
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
551 |
+
|
552 |
+
|
553 |
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
554 |
+
"""Save an image.
|
555 |
+
|
556 |
+
Args:
|
557 |
+
image (`PIL.Image`):
|
558 |
+
The image to be saved.
|
559 |
+
path (`str`):
|
560 |
+
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
|
561 |
+
basename (`str`):
|
562 |
+
The base filename which will be applied to `filename pattern`.
|
563 |
+
seed, prompt, short_filename,
|
564 |
+
extension (`str`):
|
565 |
+
Image file extension, default is `png`.
|
566 |
+
pngsectionname (`str`):
|
567 |
+
Specify the name of the section which `info` will be saved in.
|
568 |
+
info (`str` or `PngImagePlugin.iTXt`):
|
569 |
+
PNG info chunks.
|
570 |
+
existing_info (`dict`):
|
571 |
+
Additional PNG info. `existing_info == {pngsectionname: info, ...}`
|
572 |
+
no_prompt:
|
573 |
+
TODO I don't know its meaning.
|
574 |
+
p (`StableDiffusionProcessing`)
|
575 |
+
forced_filename (`str`):
|
576 |
+
If specified, `basename` and filename pattern will be ignored.
|
577 |
+
save_to_dirs (bool):
|
578 |
+
If true, the image will be saved into a subdirectory of `path`.
|
579 |
+
|
580 |
+
Returns: (fullfn, txt_fullfn)
|
581 |
+
fullfn (`str`):
|
582 |
+
The full path of the saved imaged.
|
583 |
+
txt_fullfn (`str` or None):
|
584 |
+
If a text file is saved for this image, this will be its full path. Otherwise None.
|
585 |
+
"""
|
586 |
+
namegen = FilenameGenerator(p, seed, prompt, image)
|
587 |
+
|
588 |
+
if save_to_dirs is None:
|
589 |
+
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
590 |
+
|
591 |
+
if save_to_dirs:
|
592 |
+
dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
|
593 |
+
path = os.path.join(path, dirname)
|
594 |
+
|
595 |
+
os.makedirs(path, exist_ok=True)
|
596 |
+
|
597 |
+
if forced_filename is None:
|
598 |
+
if short_filename or seed is None:
|
599 |
+
file_decoration = ""
|
600 |
+
elif opts.save_to_dirs:
|
601 |
+
file_decoration = opts.samples_filename_pattern or "[seed]"
|
602 |
+
else:
|
603 |
+
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
604 |
+
|
605 |
+
file_decoration = namegen.apply(file_decoration) + suffix
|
606 |
+
|
607 |
+
add_number = opts.save_images_add_number or file_decoration == ''
|
608 |
+
|
609 |
+
if file_decoration != "" and add_number:
|
610 |
+
file_decoration = f"-{file_decoration}"
|
611 |
+
|
612 |
+
if add_number:
|
613 |
+
basecount = get_next_sequence_number(path, basename)
|
614 |
+
fullfn = None
|
615 |
+
for i in range(500):
|
616 |
+
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
617 |
+
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
618 |
+
if not os.path.exists(fullfn):
|
619 |
+
break
|
620 |
+
else:
|
621 |
+
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
|
622 |
+
else:
|
623 |
+
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
624 |
+
|
625 |
+
pnginfo = existing_info or {}
|
626 |
+
if info is not None:
|
627 |
+
pnginfo[pnginfo_section_name] = info
|
628 |
+
|
629 |
+
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
|
630 |
+
script_callbacks.before_image_saved_callback(params)
|
631 |
+
|
632 |
+
image = params.image
|
633 |
+
fullfn = params.filename
|
634 |
+
info = params.pnginfo.get(pnginfo_section_name, None)
|
635 |
+
|
636 |
+
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
637 |
+
"""
|
638 |
+
save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
639 |
+
"""
|
640 |
+
temp_file_path = f"{filename_without_extension}.tmp"
|
641 |
+
|
642 |
+
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
643 |
+
|
644 |
+
os.replace(temp_file_path, filename_without_extension + extension)
|
645 |
+
|
646 |
+
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
647 |
+
if hasattr(os, 'statvfs'):
|
648 |
+
max_name_len = os.statvfs(path).f_namemax
|
649 |
+
fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
|
650 |
+
params.filename = fullfn_without_extension + extension
|
651 |
+
fullfn = params.filename
|
652 |
+
_atomically_save_image(image, fullfn_without_extension, extension)
|
653 |
+
|
654 |
+
image.already_saved_as = fullfn
|
655 |
+
|
656 |
+
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
657 |
+
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
|
658 |
+
ratio = image.width / image.height
|
659 |
+
resize_to = None
|
660 |
+
if oversize and ratio > 1:
|
661 |
+
resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
|
662 |
+
elif oversize:
|
663 |
+
resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
|
664 |
+
|
665 |
+
if resize_to is not None:
|
666 |
+
try:
|
667 |
+
# Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
|
668 |
+
image = image.resize(resize_to, LANCZOS)
|
669 |
+
except Exception:
|
670 |
+
image = image.resize(resize_to)
|
671 |
+
try:
|
672 |
+
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
673 |
+
except Exception as e:
|
674 |
+
errors.display(e, "saving image as downscaled JPG")
|
675 |
+
|
676 |
+
if opts.save_txt and info is not None:
|
677 |
+
txt_fullfn = f"{fullfn_without_extension}.txt"
|
678 |
+
with open(txt_fullfn, "w", encoding="utf8") as file:
|
679 |
+
file.write(f"{info}\n")
|
680 |
+
else:
|
681 |
+
txt_fullfn = None
|
682 |
+
|
683 |
+
script_callbacks.image_saved_callback(params)
|
684 |
+
|
685 |
+
return fullfn, txt_fullfn
|
686 |
+
|
687 |
+
|
688 |
+
IGNORED_INFO_KEYS = {
|
689 |
+
'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
690 |
+
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
691 |
+
'icc_profile', 'chromaticity', 'photoshop',
|
692 |
+
}
|
693 |
+
|
694 |
+
|
695 |
+
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
696 |
+
items = (image.info or {}).copy()
|
697 |
+
|
698 |
+
geninfo = items.pop('parameters', None)
|
699 |
+
|
700 |
+
if "exif" in items:
|
701 |
+
exif = piexif.load(items["exif"])
|
702 |
+
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
703 |
+
try:
|
704 |
+
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
705 |
+
except ValueError:
|
706 |
+
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
707 |
+
|
708 |
+
if exif_comment:
|
709 |
+
items['exif comment'] = exif_comment
|
710 |
+
geninfo = exif_comment
|
711 |
+
|
712 |
+
for field in IGNORED_INFO_KEYS:
|
713 |
+
items.pop(field, None)
|
714 |
+
|
715 |
+
if items.get("Software", None) == "NovelAI":
|
716 |
+
try:
|
717 |
+
json_info = json.loads(items["Comment"])
|
718 |
+
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
719 |
+
|
720 |
+
geninfo = f"""{items["Description"]}
|
721 |
+
Negative prompt: {json_info["uc"]}
|
722 |
+
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
723 |
+
except Exception:
|
724 |
+
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
|
725 |
+
|
726 |
+
return geninfo, items
|
727 |
+
|
728 |
+
|
729 |
+
def image_data(data):
|
730 |
+
import gradio as gr
|
731 |
+
|
732 |
+
try:
|
733 |
+
image = Image.open(io.BytesIO(data))
|
734 |
+
textinfo, _ = read_info_from_image(image)
|
735 |
+
return textinfo, None
|
736 |
+
except Exception:
|
737 |
+
pass
|
738 |
+
|
739 |
+
try:
|
740 |
+
text = data.decode('utf8')
|
741 |
+
assert len(text) < 10000
|
742 |
+
return text, None
|
743 |
+
|
744 |
+
except Exception:
|
745 |
+
pass
|
746 |
+
|
747 |
+
return gr.update(), None
|
748 |
+
|
749 |
+
|
750 |
+
def flatten(img, bgcolor):
|
751 |
+
"""replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
|
752 |
+
|
753 |
+
if img.mode == "RGBA":
|
754 |
+
background = Image.new('RGBA', img.size, bgcolor)
|
755 |
+
background.paste(img, mask=img)
|
756 |
+
img = background
|
757 |
+
|
758 |
+
return img.convert('RGB')
|
modules/img2img.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from contextlib import closing
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from modules import sd_samplers, images as imgutil
|
10 |
+
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
11 |
+
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
12 |
+
from modules.shared import opts, state
|
13 |
+
from modules.images import save_image
|
14 |
+
import modules.shared as shared
|
15 |
+
import modules.processing as processing
|
16 |
+
from modules.ui import plaintext_to_html
|
17 |
+
import modules.scripts
|
18 |
+
|
19 |
+
|
20 |
+
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
21 |
+
processing.fix_seed(p)
|
22 |
+
|
23 |
+
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
|
24 |
+
|
25 |
+
is_inpaint_batch = False
|
26 |
+
if inpaint_mask_dir:
|
27 |
+
inpaint_masks = shared.listfiles(inpaint_mask_dir)
|
28 |
+
is_inpaint_batch = bool(inpaint_masks)
|
29 |
+
|
30 |
+
if is_inpaint_batch:
|
31 |
+
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
32 |
+
|
33 |
+
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
34 |
+
|
35 |
+
save_normally = output_dir == ''
|
36 |
+
|
37 |
+
p.do_not_save_grid = True
|
38 |
+
p.do_not_save_samples = not save_normally
|
39 |
+
|
40 |
+
state.job_count = len(images) * p.n_iter
|
41 |
+
|
42 |
+
# extract "default" params to use in case getting png info fails
|
43 |
+
prompt = p.prompt
|
44 |
+
negative_prompt = p.negative_prompt
|
45 |
+
seed = p.seed
|
46 |
+
cfg_scale = p.cfg_scale
|
47 |
+
sampler_name = p.sampler_name
|
48 |
+
steps = p.steps
|
49 |
+
|
50 |
+
for i, image in enumerate(images):
|
51 |
+
state.job = f"{i+1} out of {len(images)}"
|
52 |
+
if state.skipped:
|
53 |
+
state.skipped = False
|
54 |
+
|
55 |
+
if state.interrupted:
|
56 |
+
break
|
57 |
+
|
58 |
+
p.filename = os.path.basename(image)
|
59 |
+
|
60 |
+
try:
|
61 |
+
img = Image.open(image)
|
62 |
+
except UnidentifiedImageError as e:
|
63 |
+
print(e)
|
64 |
+
continue
|
65 |
+
# Use the EXIF orientation of photos taken by smartphones.
|
66 |
+
img = ImageOps.exif_transpose(img)
|
67 |
+
|
68 |
+
if to_scale:
|
69 |
+
p.width = int(img.width * scale_by)
|
70 |
+
p.height = int(img.height * scale_by)
|
71 |
+
|
72 |
+
p.init_images = [img] * p.batch_size
|
73 |
+
|
74 |
+
image_path = Path(image)
|
75 |
+
if is_inpaint_batch:
|
76 |
+
# try to find corresponding mask for an image using simple filename matching
|
77 |
+
if len(inpaint_masks) == 1:
|
78 |
+
mask_image_path = inpaint_masks[0]
|
79 |
+
else:
|
80 |
+
# try to find corresponding mask for an image using simple filename matching
|
81 |
+
mask_image_dir = Path(inpaint_mask_dir)
|
82 |
+
masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
|
83 |
+
|
84 |
+
if len(masks_found) == 0:
|
85 |
+
print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
|
86 |
+
continue
|
87 |
+
|
88 |
+
# it should contain only 1 matching mask
|
89 |
+
# otherwise user has many masks with the same name but different extensions
|
90 |
+
mask_image_path = masks_found[0]
|
91 |
+
|
92 |
+
mask_image = Image.open(mask_image_path)
|
93 |
+
p.image_mask = mask_image
|
94 |
+
|
95 |
+
if use_png_info:
|
96 |
+
try:
|
97 |
+
info_img = img
|
98 |
+
if png_info_dir:
|
99 |
+
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
100 |
+
info_img = Image.open(info_img_path)
|
101 |
+
geninfo, _ = imgutil.read_info_from_image(info_img)
|
102 |
+
parsed_parameters = parse_generation_parameters(geninfo)
|
103 |
+
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
104 |
+
except Exception:
|
105 |
+
parsed_parameters = {}
|
106 |
+
|
107 |
+
p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
|
108 |
+
p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
|
109 |
+
p.seed = int(parsed_parameters.get("Seed", seed))
|
110 |
+
p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
|
111 |
+
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
112 |
+
p.steps = int(parsed_parameters.get("Steps", steps))
|
113 |
+
|
114 |
+
proc = modules.scripts.scripts_img2img.run(p, *args)
|
115 |
+
if proc is None:
|
116 |
+
proc = process_images(p)
|
117 |
+
|
118 |
+
for n, processed_image in enumerate(proc.images):
|
119 |
+
filename = image_path.stem
|
120 |
+
infotext = proc.infotext(p, n)
|
121 |
+
relpath = os.path.dirname(os.path.relpath(image, input_dir))
|
122 |
+
|
123 |
+
if n > 0:
|
124 |
+
filename += f"-{n}"
|
125 |
+
|
126 |
+
if not save_normally:
|
127 |
+
os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
|
128 |
+
if processed_image.mode == 'RGBA':
|
129 |
+
processed_image = processed_image.convert("RGB")
|
130 |
+
save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
|
131 |
+
|
132 |
+
|
133 |
+
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
134 |
+
override_settings = create_override_settings_dict(override_settings_texts)
|
135 |
+
|
136 |
+
is_batch = mode == 5
|
137 |
+
|
138 |
+
if mode == 0: # img2img
|
139 |
+
image = init_img.convert("RGB")
|
140 |
+
mask = None
|
141 |
+
elif mode == 1: # img2img sketch
|
142 |
+
image = sketch.convert("RGB")
|
143 |
+
mask = None
|
144 |
+
elif mode == 2: # inpaint
|
145 |
+
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
146 |
+
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
147 |
+
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
|
148 |
+
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
|
149 |
+
image = image.convert("RGB")
|
150 |
+
elif mode == 3: # inpaint sketch
|
151 |
+
image = inpaint_color_sketch
|
152 |
+
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
153 |
+
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
154 |
+
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
155 |
+
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
156 |
+
blur = ImageFilter.GaussianBlur(mask_blur)
|
157 |
+
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
158 |
+
image = image.convert("RGB")
|
159 |
+
elif mode == 4: # inpaint upload mask
|
160 |
+
image = init_img_inpaint
|
161 |
+
mask = init_mask_inpaint
|
162 |
+
else:
|
163 |
+
image = None
|
164 |
+
mask = None
|
165 |
+
|
166 |
+
# Use the EXIF orientation of photos taken by smartphones.
|
167 |
+
if image is not None:
|
168 |
+
image = ImageOps.exif_transpose(image)
|
169 |
+
|
170 |
+
if selected_scale_tab == 1 and not is_batch:
|
171 |
+
assert image, "Can't scale by because no image is selected"
|
172 |
+
|
173 |
+
width = int(image.width * scale_by)
|
174 |
+
height = int(image.height * scale_by)
|
175 |
+
|
176 |
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
177 |
+
|
178 |
+
p = StableDiffusionProcessingImg2Img(
|
179 |
+
sd_model=shared.sd_model,
|
180 |
+
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
|
181 |
+
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
|
182 |
+
prompt=prompt,
|
183 |
+
negative_prompt=negative_prompt,
|
184 |
+
styles=prompt_styles,
|
185 |
+
seed=seed,
|
186 |
+
subseed=subseed,
|
187 |
+
subseed_strength=subseed_strength,
|
188 |
+
seed_resize_from_h=seed_resize_from_h,
|
189 |
+
seed_resize_from_w=seed_resize_from_w,
|
190 |
+
seed_enable_extras=seed_enable_extras,
|
191 |
+
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
|
192 |
+
batch_size=batch_size,
|
193 |
+
n_iter=n_iter,
|
194 |
+
steps=steps,
|
195 |
+
cfg_scale=cfg_scale,
|
196 |
+
width=width,
|
197 |
+
height=height,
|
198 |
+
restore_faces=restore_faces,
|
199 |
+
tiling=tiling,
|
200 |
+
init_images=[image],
|
201 |
+
mask=mask,
|
202 |
+
mask_blur=mask_blur,
|
203 |
+
inpainting_fill=inpainting_fill,
|
204 |
+
resize_mode=resize_mode,
|
205 |
+
denoising_strength=denoising_strength,
|
206 |
+
image_cfg_scale=image_cfg_scale,
|
207 |
+
inpaint_full_res=inpaint_full_res,
|
208 |
+
inpaint_full_res_padding=inpaint_full_res_padding,
|
209 |
+
inpainting_mask_invert=inpainting_mask_invert,
|
210 |
+
override_settings=override_settings,
|
211 |
+
)
|
212 |
+
|
213 |
+
p.scripts = modules.scripts.scripts_img2img
|
214 |
+
p.script_args = args
|
215 |
+
|
216 |
+
p.user = request.username
|
217 |
+
|
218 |
+
if shared.cmd_opts.enable_console_prompts:
|
219 |
+
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
220 |
+
|
221 |
+
if mask:
|
222 |
+
p.extra_generation_params["Mask blur"] = mask_blur
|
223 |
+
|
224 |
+
with closing(p):
|
225 |
+
if is_batch:
|
226 |
+
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
227 |
+
|
228 |
+
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
229 |
+
|
230 |
+
processed = Processed(p, [], p.seed, "")
|
231 |
+
else:
|
232 |
+
processed = modules.scripts.scripts_img2img.run(p, *args)
|
233 |
+
if processed is None:
|
234 |
+
processed = process_images(p)
|
235 |
+
|
236 |
+
shared.total_tqdm.clear()
|
237 |
+
|
238 |
+
generation_info_js = processed.js()
|
239 |
+
if opts.samples_log_stdout:
|
240 |
+
print(generation_info_js)
|
241 |
+
|
242 |
+
if opts.do_not_show_images:
|
243 |
+
processed.images = []
|
244 |
+
|
245 |
+
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
modules/import_hook.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
|
4 |
+
if "--xformers" not in "".join(sys.argv):
|
5 |
+
sys.modules["xformers"] = None
|
modules/interrogate.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from collections import namedtuple
|
4 |
+
from pathlib import Path
|
5 |
+
import re
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.hub
|
9 |
+
|
10 |
+
from torchvision import transforms
|
11 |
+
from torchvision.transforms.functional import InterpolationMode
|
12 |
+
|
13 |
+
from modules import devices, paths, shared, lowvram, modelloader, errors
|
14 |
+
|
15 |
+
blip_image_eval_size = 384
|
16 |
+
clip_model_name = 'ViT-L/14'
|
17 |
+
|
18 |
+
Category = namedtuple("Category", ["name", "topn", "items"])
|
19 |
+
|
20 |
+
re_topn = re.compile(r"\.top(\d+)\.")
|
21 |
+
|
22 |
+
def category_types():
|
23 |
+
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
24 |
+
|
25 |
+
|
26 |
+
def download_default_clip_interrogate_categories(content_dir):
|
27 |
+
print("Downloading CLIP categories...")
|
28 |
+
|
29 |
+
tmpdir = f"{content_dir}_tmp"
|
30 |
+
category_types = ["artists", "flavors", "mediums", "movements"]
|
31 |
+
|
32 |
+
try:
|
33 |
+
os.makedirs(tmpdir, exist_ok=True)
|
34 |
+
for category_type in category_types:
|
35 |
+
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
36 |
+
os.rename(tmpdir, content_dir)
|
37 |
+
|
38 |
+
except Exception as e:
|
39 |
+
errors.display(e, "downloading default CLIP interrogate categories")
|
40 |
+
finally:
|
41 |
+
if os.path.exists(tmpdir):
|
42 |
+
os.removedirs(tmpdir)
|
43 |
+
|
44 |
+
|
45 |
+
class InterrogateModels:
|
46 |
+
blip_model = None
|
47 |
+
clip_model = None
|
48 |
+
clip_preprocess = None
|
49 |
+
dtype = None
|
50 |
+
running_on_cpu = None
|
51 |
+
|
52 |
+
def __init__(self, content_dir):
|
53 |
+
self.loaded_categories = None
|
54 |
+
self.skip_categories = []
|
55 |
+
self.content_dir = content_dir
|
56 |
+
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
57 |
+
|
58 |
+
def categories(self):
|
59 |
+
if not os.path.exists(self.content_dir):
|
60 |
+
download_default_clip_interrogate_categories(self.content_dir)
|
61 |
+
|
62 |
+
if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
|
63 |
+
return self.loaded_categories
|
64 |
+
|
65 |
+
self.loaded_categories = []
|
66 |
+
|
67 |
+
if os.path.exists(self.content_dir):
|
68 |
+
self.skip_categories = shared.opts.interrogate_clip_skip_categories
|
69 |
+
category_types = []
|
70 |
+
for filename in Path(self.content_dir).glob('*.txt'):
|
71 |
+
category_types.append(filename.stem)
|
72 |
+
if filename.stem in self.skip_categories:
|
73 |
+
continue
|
74 |
+
m = re_topn.search(filename.stem)
|
75 |
+
topn = 1 if m is None else int(m.group(1))
|
76 |
+
with open(filename, "r", encoding="utf8") as file:
|
77 |
+
lines = [x.strip() for x in file.readlines()]
|
78 |
+
|
79 |
+
self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
|
80 |
+
|
81 |
+
return self.loaded_categories
|
82 |
+
|
83 |
+
def create_fake_fairscale(self):
|
84 |
+
class FakeFairscale:
|
85 |
+
def checkpoint_wrapper(self):
|
86 |
+
pass
|
87 |
+
|
88 |
+
sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
|
89 |
+
|
90 |
+
def load_blip_model(self):
|
91 |
+
self.create_fake_fairscale()
|
92 |
+
import models.blip
|
93 |
+
|
94 |
+
files = modelloader.load_models(
|
95 |
+
model_path=os.path.join(paths.models_path, "BLIP"),
|
96 |
+
model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
|
97 |
+
ext_filter=[".pth"],
|
98 |
+
download_name='model_base_caption_capfilt_large.pth',
|
99 |
+
)
|
100 |
+
|
101 |
+
blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
102 |
+
blip_model.eval()
|
103 |
+
|
104 |
+
return blip_model
|
105 |
+
|
106 |
+
def load_clip_model(self):
|
107 |
+
import clip
|
108 |
+
|
109 |
+
if self.running_on_cpu:
|
110 |
+
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
111 |
+
else:
|
112 |
+
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
|
113 |
+
|
114 |
+
model.eval()
|
115 |
+
model = model.to(devices.device_interrogate)
|
116 |
+
|
117 |
+
return model, preprocess
|
118 |
+
|
119 |
+
def load(self):
|
120 |
+
if self.blip_model is None:
|
121 |
+
self.blip_model = self.load_blip_model()
|
122 |
+
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
123 |
+
self.blip_model = self.blip_model.half()
|
124 |
+
|
125 |
+
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
126 |
+
|
127 |
+
if self.clip_model is None:
|
128 |
+
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
129 |
+
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
130 |
+
self.clip_model = self.clip_model.half()
|
131 |
+
|
132 |
+
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
133 |
+
|
134 |
+
self.dtype = next(self.clip_model.parameters()).dtype
|
135 |
+
|
136 |
+
def send_clip_to_ram(self):
|
137 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
138 |
+
if self.clip_model is not None:
|
139 |
+
self.clip_model = self.clip_model.to(devices.cpu)
|
140 |
+
|
141 |
+
def send_blip_to_ram(self):
|
142 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
143 |
+
if self.blip_model is not None:
|
144 |
+
self.blip_model = self.blip_model.to(devices.cpu)
|
145 |
+
|
146 |
+
def unload(self):
|
147 |
+
self.send_clip_to_ram()
|
148 |
+
self.send_blip_to_ram()
|
149 |
+
|
150 |
+
devices.torch_gc()
|
151 |
+
|
152 |
+
def rank(self, image_features, text_array, top_count=1):
|
153 |
+
import clip
|
154 |
+
|
155 |
+
devices.torch_gc()
|
156 |
+
|
157 |
+
if shared.opts.interrogate_clip_dict_limit != 0:
|
158 |
+
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
159 |
+
|
160 |
+
top_count = min(top_count, len(text_array))
|
161 |
+
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
|
162 |
+
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
163 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
164 |
+
|
165 |
+
similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
|
166 |
+
for i in range(image_features.shape[0]):
|
167 |
+
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
168 |
+
similarity /= image_features.shape[0]
|
169 |
+
|
170 |
+
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
171 |
+
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
172 |
+
|
173 |
+
def generate_caption(self, pil_image):
|
174 |
+
gpu_image = transforms.Compose([
|
175 |
+
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
176 |
+
transforms.ToTensor(),
|
177 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
178 |
+
])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
179 |
+
|
180 |
+
with torch.no_grad():
|
181 |
+
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
182 |
+
|
183 |
+
return caption[0]
|
184 |
+
|
185 |
+
def interrogate(self, pil_image):
|
186 |
+
res = ""
|
187 |
+
shared.state.begin(job="interrogate")
|
188 |
+
try:
|
189 |
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
190 |
+
lowvram.send_everything_to_cpu()
|
191 |
+
devices.torch_gc()
|
192 |
+
|
193 |
+
self.load()
|
194 |
+
|
195 |
+
caption = self.generate_caption(pil_image)
|
196 |
+
self.send_blip_to_ram()
|
197 |
+
devices.torch_gc()
|
198 |
+
|
199 |
+
res = caption
|
200 |
+
|
201 |
+
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
202 |
+
|
203 |
+
with torch.no_grad(), devices.autocast():
|
204 |
+
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
205 |
+
|
206 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
207 |
+
|
208 |
+
for cat in self.categories():
|
209 |
+
matches = self.rank(image_features, cat.items, top_count=cat.topn)
|
210 |
+
for match, score in matches:
|
211 |
+
if shared.opts.interrogate_return_ranks:
|
212 |
+
res += f", ({match}:{score/100:.3f})"
|
213 |
+
else:
|
214 |
+
res += f", {match}"
|
215 |
+
|
216 |
+
except Exception:
|
217 |
+
errors.report("Error interrogating", exc_info=True)
|
218 |
+
res += "<error>"
|
219 |
+
|
220 |
+
self.unload()
|
221 |
+
shared.state.end()
|
222 |
+
|
223 |
+
return res
|
modules/launch_utils.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this scripts installs necessary requirements and launches main program in webui.py
|
2 |
+
import re
|
3 |
+
import subprocess
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import importlib.util
|
7 |
+
import platform
|
8 |
+
import json
|
9 |
+
from functools import lru_cache
|
10 |
+
|
11 |
+
from modules import cmd_args, errors
|
12 |
+
from modules.paths_internal import script_path, extensions_dir
|
13 |
+
from modules.timer import startup_timer
|
14 |
+
|
15 |
+
args, _ = cmd_args.parser.parse_known_args()
|
16 |
+
|
17 |
+
python = sys.executable
|
18 |
+
git = os.environ.get('GIT', "git")
|
19 |
+
index_url = os.environ.get('INDEX_URL', "")
|
20 |
+
dir_repos = "repositories"
|
21 |
+
|
22 |
+
# Whether to default to printing command output
|
23 |
+
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
24 |
+
|
25 |
+
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
26 |
+
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
27 |
+
|
28 |
+
|
29 |
+
def check_python_version():
|
30 |
+
is_windows = platform.system() == "Windows"
|
31 |
+
major = sys.version_info.major
|
32 |
+
minor = sys.version_info.minor
|
33 |
+
micro = sys.version_info.micro
|
34 |
+
|
35 |
+
if is_windows:
|
36 |
+
supported_minors = [10]
|
37 |
+
else:
|
38 |
+
supported_minors = [7, 8, 9, 10, 11]
|
39 |
+
|
40 |
+
if not (major == 3 and minor in supported_minors):
|
41 |
+
import modules.errors
|
42 |
+
|
43 |
+
modules.errors.print_error_explanation(f"""
|
44 |
+
INCOMPATIBLE PYTHON VERSION
|
45 |
+
|
46 |
+
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
47 |
+
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
48 |
+
or any other error regarding unsuccessful package (library) installation,
|
49 |
+
please downgrade (or upgrade) to the latest version of 3.10 Python
|
50 |
+
and delete current Python and "venv" folder in WebUI's directory.
|
51 |
+
|
52 |
+
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
53 |
+
|
54 |
+
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
55 |
+
|
56 |
+
Use --skip-python-version-check to suppress this warning.
|
57 |
+
""")
|
58 |
+
|
59 |
+
|
60 |
+
@lru_cache()
|
61 |
+
def commit_hash():
|
62 |
+
try:
|
63 |
+
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
64 |
+
except Exception:
|
65 |
+
return "<none>"
|
66 |
+
|
67 |
+
|
68 |
+
@lru_cache()
|
69 |
+
def git_tag():
|
70 |
+
try:
|
71 |
+
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
72 |
+
except Exception:
|
73 |
+
try:
|
74 |
+
|
75 |
+
changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")
|
76 |
+
with open(changelog_md, "r", encoding="utf-8") as file:
|
77 |
+
line = next((line.strip() for line in file if line.strip()), "<none>")
|
78 |
+
line = line.replace("## ", "")
|
79 |
+
return line
|
80 |
+
except Exception:
|
81 |
+
return "<none>"
|
82 |
+
|
83 |
+
|
84 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
85 |
+
if desc is not None:
|
86 |
+
print(desc)
|
87 |
+
|
88 |
+
run_kwargs = {
|
89 |
+
"args": command,
|
90 |
+
"shell": True,
|
91 |
+
"env": os.environ if custom_env is None else custom_env,
|
92 |
+
"encoding": 'utf8',
|
93 |
+
"errors": 'ignore',
|
94 |
+
}
|
95 |
+
|
96 |
+
if not live:
|
97 |
+
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
98 |
+
|
99 |
+
result = subprocess.run(**run_kwargs)
|
100 |
+
|
101 |
+
if result.returncode != 0:
|
102 |
+
error_bits = [
|
103 |
+
f"{errdesc or 'Error running command'}.",
|
104 |
+
f"Command: {command}",
|
105 |
+
f"Error code: {result.returncode}",
|
106 |
+
]
|
107 |
+
if result.stdout:
|
108 |
+
error_bits.append(f"stdout: {result.stdout}")
|
109 |
+
if result.stderr:
|
110 |
+
error_bits.append(f"stderr: {result.stderr}")
|
111 |
+
raise RuntimeError("\n".join(error_bits))
|
112 |
+
|
113 |
+
return (result.stdout or "")
|
114 |
+
|
115 |
+
|
116 |
+
def is_installed(package):
|
117 |
+
try:
|
118 |
+
spec = importlib.util.find_spec(package)
|
119 |
+
except ModuleNotFoundError:
|
120 |
+
return False
|
121 |
+
|
122 |
+
return spec is not None
|
123 |
+
|
124 |
+
|
125 |
+
def repo_dir(name):
|
126 |
+
return os.path.join(script_path, dir_repos, name)
|
127 |
+
|
128 |
+
|
129 |
+
def run_pip(command, desc=None, live=default_command_live):
|
130 |
+
if args.skip_install:
|
131 |
+
return
|
132 |
+
|
133 |
+
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
134 |
+
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
135 |
+
|
136 |
+
|
137 |
+
def check_run_python(code: str) -> bool:
|
138 |
+
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
139 |
+
return result.returncode == 0
|
140 |
+
|
141 |
+
|
142 |
+
def git_clone(url, dir, name, commithash=None):
|
143 |
+
# TODO clone into temporary dir and move if successful
|
144 |
+
|
145 |
+
if os.path.exists(dir):
|
146 |
+
if commithash is None:
|
147 |
+
return
|
148 |
+
|
149 |
+
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
150 |
+
if current_hash == commithash:
|
151 |
+
return
|
152 |
+
|
153 |
+
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
154 |
+
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
155 |
+
return
|
156 |
+
|
157 |
+
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
158 |
+
|
159 |
+
if commithash is not None:
|
160 |
+
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
161 |
+
|
162 |
+
|
163 |
+
def git_pull_recursive(dir):
|
164 |
+
for subdir, _, _ in os.walk(dir):
|
165 |
+
if os.path.exists(os.path.join(subdir, '.git')):
|
166 |
+
try:
|
167 |
+
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
168 |
+
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
169 |
+
except subprocess.CalledProcessError as e:
|
170 |
+
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
171 |
+
|
172 |
+
|
173 |
+
def version_check(commit):
|
174 |
+
try:
|
175 |
+
import requests
|
176 |
+
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
177 |
+
if commit != "<none>" and commits['commit']['sha'] != commit:
|
178 |
+
print("--------------------------------------------------------")
|
179 |
+
print("| You are not up to date with the most recent release. |")
|
180 |
+
print("| Consider running `git pull` to update. |")
|
181 |
+
print("--------------------------------------------------------")
|
182 |
+
elif commits['commit']['sha'] == commit:
|
183 |
+
print("You are up to date with the most recent release.")
|
184 |
+
else:
|
185 |
+
print("Not a git clone, can't perform version check.")
|
186 |
+
except Exception as e:
|
187 |
+
print("version check failed", e)
|
188 |
+
|
189 |
+
|
190 |
+
def run_extension_installer(extension_dir):
|
191 |
+
path_installer = os.path.join(extension_dir, "install.py")
|
192 |
+
if not os.path.isfile(path_installer):
|
193 |
+
return
|
194 |
+
|
195 |
+
try:
|
196 |
+
env = os.environ.copy()
|
197 |
+
env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
198 |
+
|
199 |
+
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
200 |
+
except Exception as e:
|
201 |
+
errors.report(str(e))
|
202 |
+
|
203 |
+
|
204 |
+
def list_extensions(settings_file):
|
205 |
+
settings = {}
|
206 |
+
|
207 |
+
try:
|
208 |
+
if os.path.isfile(settings_file):
|
209 |
+
with open(settings_file, "r", encoding="utf8") as file:
|
210 |
+
settings = json.load(file)
|
211 |
+
except Exception:
|
212 |
+
errors.report("Could not load settings", exc_info=True)
|
213 |
+
|
214 |
+
disabled_extensions = set(settings.get('disabled_extensions', []))
|
215 |
+
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
216 |
+
|
217 |
+
if disable_all_extensions != 'none':
|
218 |
+
return []
|
219 |
+
|
220 |
+
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
221 |
+
|
222 |
+
|
223 |
+
def run_extensions_installers(settings_file):
|
224 |
+
if not os.path.isdir(extensions_dir):
|
225 |
+
return
|
226 |
+
|
227 |
+
with startup_timer.subcategory("run extensions installers"):
|
228 |
+
for dirname_extension in list_extensions(settings_file):
|
229 |
+
path = os.path.join(extensions_dir, dirname_extension)
|
230 |
+
|
231 |
+
if os.path.isdir(path):
|
232 |
+
run_extension_installer(path)
|
233 |
+
startup_timer.record(dirname_extension)
|
234 |
+
|
235 |
+
|
236 |
+
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
237 |
+
|
238 |
+
|
239 |
+
def requirements_met(requirements_file):
|
240 |
+
"""
|
241 |
+
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
|
242 |
+
are already installed. Returns True if so, False if not installed or parsing fails.
|
243 |
+
"""
|
244 |
+
|
245 |
+
import importlib.metadata
|
246 |
+
import packaging.version
|
247 |
+
|
248 |
+
with open(requirements_file, "r", encoding="utf8") as file:
|
249 |
+
for line in file:
|
250 |
+
if line.strip() == "":
|
251 |
+
continue
|
252 |
+
|
253 |
+
m = re.match(re_requirement, line)
|
254 |
+
if m is None:
|
255 |
+
return False
|
256 |
+
|
257 |
+
package = m.group(1).strip()
|
258 |
+
version_required = (m.group(2) or "").strip()
|
259 |
+
|
260 |
+
if version_required == "":
|
261 |
+
continue
|
262 |
+
|
263 |
+
try:
|
264 |
+
version_installed = importlib.metadata.version(package)
|
265 |
+
except Exception:
|
266 |
+
return False
|
267 |
+
|
268 |
+
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
|
269 |
+
return False
|
270 |
+
|
271 |
+
return True
|
272 |
+
|
273 |
+
|
274 |
+
def prepare_environment():
|
275 |
+
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
276 |
+
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
277 |
+
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
278 |
+
|
279 |
+
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
280 |
+
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
281 |
+
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
282 |
+
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
283 |
+
|
284 |
+
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
285 |
+
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
286 |
+
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
287 |
+
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
288 |
+
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
289 |
+
|
290 |
+
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
291 |
+
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
292 |
+
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
293 |
+
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
294 |
+
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
295 |
+
|
296 |
+
try:
|
297 |
+
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
298 |
+
os.remove(os.path.join(script_path, "tmp", "restart"))
|
299 |
+
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
300 |
+
except OSError:
|
301 |
+
pass
|
302 |
+
|
303 |
+
if not args.skip_python_version_check:
|
304 |
+
check_python_version()
|
305 |
+
|
306 |
+
startup_timer.record("checks")
|
307 |
+
|
308 |
+
commit = commit_hash()
|
309 |
+
tag = git_tag()
|
310 |
+
startup_timer.record("git version info")
|
311 |
+
|
312 |
+
print(f"Python {sys.version}")
|
313 |
+
print(f"Version: {tag}")
|
314 |
+
print(f"Commit hash: {commit}")
|
315 |
+
|
316 |
+
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
317 |
+
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
318 |
+
startup_timer.record("install torch")
|
319 |
+
|
320 |
+
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
321 |
+
raise RuntimeError(
|
322 |
+
'Torch is not able to use GPU; '
|
323 |
+
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
324 |
+
)
|
325 |
+
startup_timer.record("torch GPU test")
|
326 |
+
|
327 |
+
|
328 |
+
if not is_installed("gfpgan"):
|
329 |
+
run_pip(f"install {gfpgan_package}", "gfpgan")
|
330 |
+
startup_timer.record("install gfpgan")
|
331 |
+
|
332 |
+
if not is_installed("clip"):
|
333 |
+
run_pip(f"install {clip_package}", "clip")
|
334 |
+
startup_timer.record("install clip")
|
335 |
+
|
336 |
+
if not is_installed("open_clip"):
|
337 |
+
run_pip(f"install {openclip_package}", "open_clip")
|
338 |
+
startup_timer.record("install open_clip")
|
339 |
+
|
340 |
+
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
341 |
+
if platform.system() == "Windows":
|
342 |
+
if platform.python_version().startswith("3.10"):
|
343 |
+
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
|
344 |
+
else:
|
345 |
+
print("Installation of xformers is not supported in this version of Python.")
|
346 |
+
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
347 |
+
if not is_installed("xformers"):
|
348 |
+
exit(0)
|
349 |
+
elif platform.system() == "Linux":
|
350 |
+
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
351 |
+
|
352 |
+
startup_timer.record("install xformers")
|
353 |
+
|
354 |
+
if not is_installed("ngrok") and args.ngrok:
|
355 |
+
run_pip("install ngrok", "ngrok")
|
356 |
+
startup_timer.record("install ngrok")
|
357 |
+
|
358 |
+
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
359 |
+
|
360 |
+
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
361 |
+
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
362 |
+
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
363 |
+
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
364 |
+
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
365 |
+
|
366 |
+
startup_timer.record("clone repositores")
|
367 |
+
|
368 |
+
if not is_installed("lpips"):
|
369 |
+
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
370 |
+
startup_timer.record("install CodeFormer requirements")
|
371 |
+
|
372 |
+
if not os.path.isfile(requirements_file):
|
373 |
+
requirements_file = os.path.join(script_path, requirements_file)
|
374 |
+
|
375 |
+
if not requirements_met(requirements_file):
|
376 |
+
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
377 |
+
startup_timer.record("install requirements")
|
378 |
+
|
379 |
+
run_extensions_installers(settings_file=args.ui_settings_file)
|
380 |
+
|
381 |
+
if args.update_check:
|
382 |
+
version_check(commit)
|
383 |
+
startup_timer.record("check version")
|
384 |
+
|
385 |
+
if args.update_all_extensions:
|
386 |
+
git_pull_recursive(extensions_dir)
|
387 |
+
startup_timer.record("update extensions")
|
388 |
+
|
389 |
+
if "--exit" in sys.argv:
|
390 |
+
print("Exiting because of --exit argument")
|
391 |
+
exit(0)
|
392 |
+
|
393 |
+
|
394 |
+
|
395 |
+
def configure_for_tests():
|
396 |
+
if "--api" not in sys.argv:
|
397 |
+
sys.argv.append("--api")
|
398 |
+
if "--ckpt" not in sys.argv:
|
399 |
+
sys.argv.append("--ckpt")
|
400 |
+
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
401 |
+
if "--skip-torch-cuda-test" not in sys.argv:
|
402 |
+
sys.argv.append("--skip-torch-cuda-test")
|
403 |
+
if "--disable-nan-check" not in sys.argv:
|
404 |
+
sys.argv.append("--disable-nan-check")
|
405 |
+
|
406 |
+
os.environ['COMMANDLINE_ARGS'] = ""
|
407 |
+
|
408 |
+
|
409 |
+
def start():
|
410 |
+
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
|
411 |
+
import webui
|
412 |
+
if '--nowebui' in sys.argv:
|
413 |
+
webui.api_only()
|
414 |
+
else:
|
415 |
+
webui.webui()
|
modules/localization.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
from modules import errors
|
5 |
+
|
6 |
+
localizations = {}
|
7 |
+
|
8 |
+
|
9 |
+
def list_localizations(dirname):
|
10 |
+
localizations.clear()
|
11 |
+
|
12 |
+
for file in os.listdir(dirname):
|
13 |
+
fn, ext = os.path.splitext(file)
|
14 |
+
if ext.lower() != ".json":
|
15 |
+
continue
|
16 |
+
|
17 |
+
localizations[fn] = os.path.join(dirname, file)
|
18 |
+
|
19 |
+
from modules import scripts
|
20 |
+
for file in scripts.list_scripts("localizations", ".json"):
|
21 |
+
fn, ext = os.path.splitext(file.filename)
|
22 |
+
localizations[fn] = file.path
|
23 |
+
|
24 |
+
|
25 |
+
def localization_js(current_localization_name: str) -> str:
|
26 |
+
fn = localizations.get(current_localization_name, None)
|
27 |
+
data = {}
|
28 |
+
if fn is not None:
|
29 |
+
try:
|
30 |
+
with open(fn, "r", encoding="utf8") as file:
|
31 |
+
data = json.load(file)
|
32 |
+
except Exception:
|
33 |
+
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
34 |
+
|
35 |
+
return f"window.localization = {json.dumps(data)}"
|
modules/lowvram.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from modules import devices
|
3 |
+
|
4 |
+
module_in_gpu = None
|
5 |
+
cpu = torch.device("cpu")
|
6 |
+
|
7 |
+
|
8 |
+
def send_everything_to_cpu():
|
9 |
+
global module_in_gpu
|
10 |
+
|
11 |
+
if module_in_gpu is not None:
|
12 |
+
module_in_gpu.to(cpu)
|
13 |
+
|
14 |
+
module_in_gpu = None
|
15 |
+
|
16 |
+
|
17 |
+
def setup_for_low_vram(sd_model, use_medvram):
|
18 |
+
sd_model.lowvram = True
|
19 |
+
|
20 |
+
parents = {}
|
21 |
+
|
22 |
+
def send_me_to_gpu(module, _):
|
23 |
+
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
24 |
+
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
25 |
+
be in CPU
|
26 |
+
"""
|
27 |
+
global module_in_gpu
|
28 |
+
|
29 |
+
module = parents.get(module, module)
|
30 |
+
|
31 |
+
if module_in_gpu == module:
|
32 |
+
return
|
33 |
+
|
34 |
+
if module_in_gpu is not None:
|
35 |
+
module_in_gpu.to(cpu)
|
36 |
+
|
37 |
+
module.to(devices.device)
|
38 |
+
module_in_gpu = module
|
39 |
+
|
40 |
+
# see below for register_forward_pre_hook;
|
41 |
+
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
42 |
+
# useless here, and we just replace those methods
|
43 |
+
|
44 |
+
first_stage_model = sd_model.first_stage_model
|
45 |
+
first_stage_model_encode = sd_model.first_stage_model.encode
|
46 |
+
first_stage_model_decode = sd_model.first_stage_model.decode
|
47 |
+
|
48 |
+
def first_stage_model_encode_wrap(x):
|
49 |
+
send_me_to_gpu(first_stage_model, None)
|
50 |
+
return first_stage_model_encode(x)
|
51 |
+
|
52 |
+
def first_stage_model_decode_wrap(z):
|
53 |
+
send_me_to_gpu(first_stage_model, None)
|
54 |
+
return first_stage_model_decode(z)
|
55 |
+
|
56 |
+
to_remain_in_cpu = [
|
57 |
+
(sd_model, 'first_stage_model'),
|
58 |
+
(sd_model, 'depth_model'),
|
59 |
+
(sd_model, 'embedder'),
|
60 |
+
(sd_model, 'model'),
|
61 |
+
(sd_model, 'embedder'),
|
62 |
+
]
|
63 |
+
|
64 |
+
is_sdxl = hasattr(sd_model, 'conditioner')
|
65 |
+
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
66 |
+
|
67 |
+
if is_sdxl:
|
68 |
+
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
69 |
+
elif is_sd2:
|
70 |
+
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
71 |
+
else:
|
72 |
+
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
|
73 |
+
|
74 |
+
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
|
75 |
+
stored = []
|
76 |
+
for obj, field in to_remain_in_cpu:
|
77 |
+
module = getattr(obj, field, None)
|
78 |
+
stored.append(module)
|
79 |
+
setattr(obj, field, None)
|
80 |
+
|
81 |
+
# send the model to GPU.
|
82 |
+
sd_model.to(devices.device)
|
83 |
+
|
84 |
+
# put modules back. the modules will be in CPU.
|
85 |
+
for (obj, field), module in zip(to_remain_in_cpu, stored):
|
86 |
+
setattr(obj, field, module)
|
87 |
+
|
88 |
+
# register hooks for those the first three models
|
89 |
+
if is_sdxl:
|
90 |
+
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
91 |
+
elif is_sd2:
|
92 |
+
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
93 |
+
sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
|
94 |
+
parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
|
95 |
+
parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
|
96 |
+
else:
|
97 |
+
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
98 |
+
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
99 |
+
|
100 |
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
101 |
+
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
102 |
+
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
103 |
+
if sd_model.depth_model:
|
104 |
+
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
105 |
+
if sd_model.embedder:
|
106 |
+
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
107 |
+
|
108 |
+
if use_medvram:
|
109 |
+
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
110 |
+
else:
|
111 |
+
diff_model = sd_model.model.diffusion_model
|
112 |
+
|
113 |
+
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
|
114 |
+
# so that only one of them is in GPU at a time
|
115 |
+
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
116 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
117 |
+
sd_model.model.to(devices.device)
|
118 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
119 |
+
|
120 |
+
# install hooks for bits of third model
|
121 |
+
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
122 |
+
for block in diff_model.input_blocks:
|
123 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
124 |
+
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
125 |
+
for block in diff_model.output_blocks:
|
126 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
127 |
+
|
128 |
+
|
129 |
+
def is_enabled(sd_model):
|
130 |
+
return getattr(sd_model, 'lowvram', False)
|
modules/mac_specific.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import platform
|
5 |
+
from modules.sd_hijack_utils import CondFunc
|
6 |
+
from packaging import version
|
7 |
+
|
8 |
+
log = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
12 |
+
# use check `getattr` and try it for compatibility.
|
13 |
+
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
|
14 |
+
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
15 |
+
def check_for_mps() -> bool:
|
16 |
+
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
17 |
+
if not getattr(torch, 'has_mps', False):
|
18 |
+
return False
|
19 |
+
try:
|
20 |
+
torch.zeros(1).to(torch.device("mps"))
|
21 |
+
return True
|
22 |
+
except Exception:
|
23 |
+
return False
|
24 |
+
else:
|
25 |
+
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
26 |
+
|
27 |
+
|
28 |
+
has_mps = check_for_mps()
|
29 |
+
|
30 |
+
|
31 |
+
def torch_mps_gc() -> None:
|
32 |
+
try:
|
33 |
+
from modules.shared import state
|
34 |
+
if state.current_latent is not None:
|
35 |
+
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
36 |
+
return
|
37 |
+
from torch.mps import empty_cache
|
38 |
+
empty_cache()
|
39 |
+
except Exception:
|
40 |
+
log.warning("MPS garbage collection failed", exc_info=True)
|
41 |
+
|
42 |
+
|
43 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
44 |
+
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
45 |
+
if input.device.type == 'mps':
|
46 |
+
output_dtype = kwargs.get('dtype', input.dtype)
|
47 |
+
if output_dtype == torch.int64:
|
48 |
+
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
49 |
+
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
50 |
+
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
51 |
+
return cumsum_func(input, *args, **kwargs)
|
52 |
+
|
53 |
+
|
54 |
+
if has_mps:
|
55 |
+
# MPS fix for randn in torchsde
|
56 |
+
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
57 |
+
|
58 |
+
if platform.mac_ver()[0].startswith("13.2."):
|
59 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
60 |
+
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
61 |
+
|
62 |
+
if version.parse(torch.__version__) < version.parse("1.13"):
|
63 |
+
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
64 |
+
|
65 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
66 |
+
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
67 |
+
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
68 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
69 |
+
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
70 |
+
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
71 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
72 |
+
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
73 |
+
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
74 |
+
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
75 |
+
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
76 |
+
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
77 |
+
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
78 |
+
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
79 |
+
|
80 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
81 |
+
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
82 |
+
|
83 |
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
84 |
+
if platform.processor() == 'i386':
|
85 |
+
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
86 |
+
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
modules/masking.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageFilter, ImageOps
|
2 |
+
|
3 |
+
|
4 |
+
def get_crop_region(mask, pad=0):
|
5 |
+
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
6 |
+
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
7 |
+
|
8 |
+
h, w = mask.shape
|
9 |
+
|
10 |
+
crop_left = 0
|
11 |
+
for i in range(w):
|
12 |
+
if not (mask[:, i] == 0).all():
|
13 |
+
break
|
14 |
+
crop_left += 1
|
15 |
+
|
16 |
+
crop_right = 0
|
17 |
+
for i in reversed(range(w)):
|
18 |
+
if not (mask[:, i] == 0).all():
|
19 |
+
break
|
20 |
+
crop_right += 1
|
21 |
+
|
22 |
+
crop_top = 0
|
23 |
+
for i in range(h):
|
24 |
+
if not (mask[i] == 0).all():
|
25 |
+
break
|
26 |
+
crop_top += 1
|
27 |
+
|
28 |
+
crop_bottom = 0
|
29 |
+
for i in reversed(range(h)):
|
30 |
+
if not (mask[i] == 0).all():
|
31 |
+
break
|
32 |
+
crop_bottom += 1
|
33 |
+
|
34 |
+
return (
|
35 |
+
int(max(crop_left-pad, 0)),
|
36 |
+
int(max(crop_top-pad, 0)),
|
37 |
+
int(min(w - crop_right + pad, w)),
|
38 |
+
int(min(h - crop_bottom + pad, h))
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
|
43 |
+
"""expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
|
44 |
+
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
|
45 |
+
|
46 |
+
x1, y1, x2, y2 = crop_region
|
47 |
+
|
48 |
+
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
49 |
+
ratio_processing = processing_width / processing_height
|
50 |
+
|
51 |
+
if ratio_crop_region > ratio_processing:
|
52 |
+
desired_height = (x2 - x1) / ratio_processing
|
53 |
+
desired_height_diff = int(desired_height - (y2-y1))
|
54 |
+
y1 -= desired_height_diff//2
|
55 |
+
y2 += desired_height_diff - desired_height_diff//2
|
56 |
+
if y2 >= image_height:
|
57 |
+
diff = y2 - image_height
|
58 |
+
y2 -= diff
|
59 |
+
y1 -= diff
|
60 |
+
if y1 < 0:
|
61 |
+
y2 -= y1
|
62 |
+
y1 -= y1
|
63 |
+
if y2 >= image_height:
|
64 |
+
y2 = image_height
|
65 |
+
else:
|
66 |
+
desired_width = (y2 - y1) * ratio_processing
|
67 |
+
desired_width_diff = int(desired_width - (x2-x1))
|
68 |
+
x1 -= desired_width_diff//2
|
69 |
+
x2 += desired_width_diff - desired_width_diff//2
|
70 |
+
if x2 >= image_width:
|
71 |
+
diff = x2 - image_width
|
72 |
+
x2 -= diff
|
73 |
+
x1 -= diff
|
74 |
+
if x1 < 0:
|
75 |
+
x2 -= x1
|
76 |
+
x1 -= x1
|
77 |
+
if x2 >= image_width:
|
78 |
+
x2 = image_width
|
79 |
+
|
80 |
+
return x1, y1, x2, y2
|
81 |
+
|
82 |
+
|
83 |
+
def fill(image, mask):
|
84 |
+
"""fills masked regions with colors from image using blur. Not extremely effective."""
|
85 |
+
|
86 |
+
image_mod = Image.new('RGBA', (image.width, image.height))
|
87 |
+
|
88 |
+
image_masked = Image.new('RGBa', (image.width, image.height))
|
89 |
+
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
|
90 |
+
|
91 |
+
image_masked = image_masked.convert('RGBa')
|
92 |
+
|
93 |
+
for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
|
94 |
+
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
|
95 |
+
for _ in range(repeats):
|
96 |
+
image_mod.alpha_composite(blurred)
|
97 |
+
|
98 |
+
return image_mod.convert("RGB")
|
99 |
+
|
modules/memmon.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import time
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class MemUsageMonitor(threading.Thread):
|
9 |
+
run_flag = None
|
10 |
+
device = None
|
11 |
+
disabled = False
|
12 |
+
opts = None
|
13 |
+
data = None
|
14 |
+
|
15 |
+
def __init__(self, name, device, opts):
|
16 |
+
threading.Thread.__init__(self)
|
17 |
+
self.name = name
|
18 |
+
self.device = device
|
19 |
+
self.opts = opts
|
20 |
+
|
21 |
+
self.daemon = True
|
22 |
+
self.run_flag = threading.Event()
|
23 |
+
self.data = defaultdict(int)
|
24 |
+
|
25 |
+
try:
|
26 |
+
self.cuda_mem_get_info()
|
27 |
+
torch.cuda.memory_stats(self.device)
|
28 |
+
except Exception as e: # AMD or whatever
|
29 |
+
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
30 |
+
self.disabled = True
|
31 |
+
|
32 |
+
def cuda_mem_get_info(self):
|
33 |
+
index = self.device.index if self.device.index is not None else torch.cuda.current_device()
|
34 |
+
return torch.cuda.mem_get_info(index)
|
35 |
+
|
36 |
+
def run(self):
|
37 |
+
if self.disabled:
|
38 |
+
return
|
39 |
+
|
40 |
+
while True:
|
41 |
+
self.run_flag.wait()
|
42 |
+
|
43 |
+
torch.cuda.reset_peak_memory_stats()
|
44 |
+
self.data.clear()
|
45 |
+
|
46 |
+
if self.opts.memmon_poll_rate <= 0:
|
47 |
+
self.run_flag.clear()
|
48 |
+
continue
|
49 |
+
|
50 |
+
self.data["min_free"] = self.cuda_mem_get_info()[0]
|
51 |
+
|
52 |
+
while self.run_flag.is_set():
|
53 |
+
free, total = self.cuda_mem_get_info()
|
54 |
+
self.data["min_free"] = min(self.data["min_free"], free)
|
55 |
+
|
56 |
+
time.sleep(1 / self.opts.memmon_poll_rate)
|
57 |
+
|
58 |
+
def dump_debug(self):
|
59 |
+
print(self, 'recorded data:')
|
60 |
+
for k, v in self.read().items():
|
61 |
+
print(k, -(v // -(1024 ** 2)))
|
62 |
+
|
63 |
+
print(self, 'raw torch memory stats:')
|
64 |
+
tm = torch.cuda.memory_stats(self.device)
|
65 |
+
for k, v in tm.items():
|
66 |
+
if 'bytes' not in k:
|
67 |
+
continue
|
68 |
+
print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
|
69 |
+
|
70 |
+
print(torch.cuda.memory_summary())
|
71 |
+
|
72 |
+
def monitor(self):
|
73 |
+
self.run_flag.set()
|
74 |
+
|
75 |
+
def read(self):
|
76 |
+
if not self.disabled:
|
77 |
+
free, total = self.cuda_mem_get_info()
|
78 |
+
self.data["free"] = free
|
79 |
+
self.data["total"] = total
|
80 |
+
|
81 |
+
torch_stats = torch.cuda.memory_stats(self.device)
|
82 |
+
self.data["active"] = torch_stats["active.all.current"]
|
83 |
+
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
|
84 |
+
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
|
85 |
+
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
|
86 |
+
self.data["system_peak"] = total - self.data["min_free"]
|
87 |
+
|
88 |
+
return self.data
|
89 |
+
|
90 |
+
def stop(self):
|
91 |
+
self.run_flag.clear()
|
92 |
+
return self.read()
|
modules/modelloader.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import importlib
|
6 |
+
from urllib.parse import urlparse
|
7 |
+
|
8 |
+
from modules import shared
|
9 |
+
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
10 |
+
from modules.paths import script_path, models_path
|
11 |
+
|
12 |
+
|
13 |
+
def load_file_from_url(
|
14 |
+
url: str,
|
15 |
+
*,
|
16 |
+
model_dir: str,
|
17 |
+
progress: bool = True,
|
18 |
+
file_name: str | None = None,
|
19 |
+
) -> str:
|
20 |
+
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
21 |
+
|
22 |
+
Returns the path to the downloaded file.
|
23 |
+
"""
|
24 |
+
os.makedirs(model_dir, exist_ok=True)
|
25 |
+
if not file_name:
|
26 |
+
parts = urlparse(url)
|
27 |
+
file_name = os.path.basename(parts.path)
|
28 |
+
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
29 |
+
if not os.path.exists(cached_file):
|
30 |
+
print(f'Downloading: "{url}" to {cached_file}\n')
|
31 |
+
from torch.hub import download_url_to_file
|
32 |
+
download_url_to_file(url, cached_file, progress=progress)
|
33 |
+
return cached_file
|
34 |
+
|
35 |
+
|
36 |
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
37 |
+
"""
|
38 |
+
A one-and done loader to try finding the desired models in specified directories.
|
39 |
+
|
40 |
+
@param download_name: Specify to download from model_url immediately.
|
41 |
+
@param model_url: If no other models are found, this will be downloaded on upscale.
|
42 |
+
@param model_path: The location to store/find models in.
|
43 |
+
@param command_path: A command-line argument to search for models in first.
|
44 |
+
@param ext_filter: An optional list of filename extensions to filter by
|
45 |
+
@return: A list of paths containing the desired model(s)
|
46 |
+
"""
|
47 |
+
output = []
|
48 |
+
|
49 |
+
try:
|
50 |
+
places = []
|
51 |
+
|
52 |
+
if command_path is not None and command_path != model_path:
|
53 |
+
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
54 |
+
if os.path.exists(pretrained_path):
|
55 |
+
print(f"Appending path: {pretrained_path}")
|
56 |
+
places.append(pretrained_path)
|
57 |
+
elif os.path.exists(command_path):
|
58 |
+
places.append(command_path)
|
59 |
+
|
60 |
+
places.append(model_path)
|
61 |
+
|
62 |
+
for place in places:
|
63 |
+
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
|
64 |
+
if os.path.islink(full_path) and not os.path.exists(full_path):
|
65 |
+
print(f"Skipping broken symlink: {full_path}")
|
66 |
+
continue
|
67 |
+
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
68 |
+
continue
|
69 |
+
if full_path not in output:
|
70 |
+
output.append(full_path)
|
71 |
+
|
72 |
+
if model_url is not None and len(output) == 0:
|
73 |
+
if download_name is not None:
|
74 |
+
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
|
75 |
+
else:
|
76 |
+
output.append(model_url)
|
77 |
+
|
78 |
+
except Exception:
|
79 |
+
pass
|
80 |
+
|
81 |
+
return output
|
82 |
+
|
83 |
+
|
84 |
+
def friendly_name(file: str):
|
85 |
+
if file.startswith("http"):
|
86 |
+
file = urlparse(file).path
|
87 |
+
|
88 |
+
file = os.path.basename(file)
|
89 |
+
model_name, extension = os.path.splitext(file)
|
90 |
+
return model_name
|
91 |
+
|
92 |
+
|
93 |
+
def cleanup_models():
|
94 |
+
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
95 |
+
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
96 |
+
# somehow auto-register and just do these things...
|
97 |
+
root_path = script_path
|
98 |
+
src_path = models_path
|
99 |
+
dest_path = os.path.join(models_path, "Stable-diffusion")
|
100 |
+
move_files(src_path, dest_path, ".ckpt")
|
101 |
+
move_files(src_path, dest_path, ".safetensors")
|
102 |
+
src_path = os.path.join(root_path, "ESRGAN")
|
103 |
+
dest_path = os.path.join(models_path, "ESRGAN")
|
104 |
+
move_files(src_path, dest_path)
|
105 |
+
src_path = os.path.join(models_path, "BSRGAN")
|
106 |
+
dest_path = os.path.join(models_path, "ESRGAN")
|
107 |
+
move_files(src_path, dest_path, ".pth")
|
108 |
+
src_path = os.path.join(root_path, "gfpgan")
|
109 |
+
dest_path = os.path.join(models_path, "GFPGAN")
|
110 |
+
move_files(src_path, dest_path)
|
111 |
+
src_path = os.path.join(root_path, "SwinIR")
|
112 |
+
dest_path = os.path.join(models_path, "SwinIR")
|
113 |
+
move_files(src_path, dest_path)
|
114 |
+
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
115 |
+
dest_path = os.path.join(models_path, "LDSR")
|
116 |
+
move_files(src_path, dest_path)
|
117 |
+
|
118 |
+
|
119 |
+
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
120 |
+
try:
|
121 |
+
os.makedirs(dest_path, exist_ok=True)
|
122 |
+
if os.path.exists(src_path):
|
123 |
+
for file in os.listdir(src_path):
|
124 |
+
fullpath = os.path.join(src_path, file)
|
125 |
+
if os.path.isfile(fullpath):
|
126 |
+
if ext_filter is not None:
|
127 |
+
if ext_filter not in file:
|
128 |
+
continue
|
129 |
+
print(f"Moving {file} from {src_path} to {dest_path}.")
|
130 |
+
try:
|
131 |
+
shutil.move(fullpath, dest_path)
|
132 |
+
except Exception:
|
133 |
+
pass
|
134 |
+
if len(os.listdir(src_path)) == 0:
|
135 |
+
print(f"Removing empty folder: {src_path}")
|
136 |
+
shutil.rmtree(src_path, True)
|
137 |
+
except Exception:
|
138 |
+
pass
|
139 |
+
|
140 |
+
|
141 |
+
def load_upscalers():
|
142 |
+
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
143 |
+
# so we'll try to import any _model.py files before looking in __subclasses__
|
144 |
+
modules_dir = os.path.join(shared.script_path, "modules")
|
145 |
+
for file in os.listdir(modules_dir):
|
146 |
+
if "_model.py" in file:
|
147 |
+
model_name = file.replace("_model.py", "")
|
148 |
+
full_model = f"modules.{model_name}_model"
|
149 |
+
try:
|
150 |
+
importlib.import_module(full_model)
|
151 |
+
except Exception:
|
152 |
+
pass
|
153 |
+
|
154 |
+
datas = []
|
155 |
+
commandline_options = vars(shared.cmd_opts)
|
156 |
+
|
157 |
+
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
158 |
+
# up with two copies of those classes. The newest copy will always be the last in the list,
|
159 |
+
# so we go from end to beginning and ignore duplicates
|
160 |
+
used_classes = {}
|
161 |
+
for cls in reversed(Upscaler.__subclasses__()):
|
162 |
+
classname = str(cls)
|
163 |
+
if classname not in used_classes:
|
164 |
+
used_classes[classname] = cls
|
165 |
+
|
166 |
+
for cls in reversed(used_classes.values()):
|
167 |
+
name = cls.__name__
|
168 |
+
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
169 |
+
commandline_model_path = commandline_options.get(cmd_name, None)
|
170 |
+
scaler = cls(commandline_model_path)
|
171 |
+
scaler.user_path = commandline_model_path
|
172 |
+
scaler.model_download_path = commandline_model_path or scaler.model_path
|
173 |
+
datas += scaler.scalers
|
174 |
+
|
175 |
+
shared.sd_upscalers = sorted(
|
176 |
+
datas,
|
177 |
+
# Special case for UpscalerNone keeps it at the beginning of the list.
|
178 |
+
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
179 |
+
)
|
modules/models/diffusion/ddpm_edit.py
ADDED
@@ -0,0 +1,1455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
wild mixture of
|
3 |
+
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
4 |
+
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
5 |
+
https://github.com/CompVis/taming-transformers
|
6 |
+
-- merci
|
7 |
+
"""
|
8 |
+
|
9 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
10 |
+
# See more details in LICENSE.
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import numpy as np
|
15 |
+
import pytorch_lightning as pl
|
16 |
+
from torch.optim.lr_scheduler import LambdaLR
|
17 |
+
from einops import rearrange, repeat
|
18 |
+
from contextlib import contextmanager
|
19 |
+
from functools import partial
|
20 |
+
from tqdm import tqdm
|
21 |
+
from torchvision.utils import make_grid
|
22 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
23 |
+
|
24 |
+
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
25 |
+
from ldm.modules.ema import LitEma
|
26 |
+
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
27 |
+
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
|
28 |
+
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
29 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
30 |
+
|
31 |
+
|
32 |
+
__conditioning_keys__ = {'concat': 'c_concat',
|
33 |
+
'crossattn': 'c_crossattn',
|
34 |
+
'adm': 'y'}
|
35 |
+
|
36 |
+
|
37 |
+
def disabled_train(self, mode=True):
|
38 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
39 |
+
does not change anymore."""
|
40 |
+
return self
|
41 |
+
|
42 |
+
|
43 |
+
def uniform_on_device(r1, r2, shape, device):
|
44 |
+
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
45 |
+
|
46 |
+
|
47 |
+
class DDPM(pl.LightningModule):
|
48 |
+
# classic DDPM with Gaussian diffusion, in image space
|
49 |
+
def __init__(self,
|
50 |
+
unet_config,
|
51 |
+
timesteps=1000,
|
52 |
+
beta_schedule="linear",
|
53 |
+
loss_type="l2",
|
54 |
+
ckpt_path=None,
|
55 |
+
ignore_keys=None,
|
56 |
+
load_only_unet=False,
|
57 |
+
monitor="val/loss",
|
58 |
+
use_ema=True,
|
59 |
+
first_stage_key="image",
|
60 |
+
image_size=256,
|
61 |
+
channels=3,
|
62 |
+
log_every_t=100,
|
63 |
+
clip_denoised=True,
|
64 |
+
linear_start=1e-4,
|
65 |
+
linear_end=2e-2,
|
66 |
+
cosine_s=8e-3,
|
67 |
+
given_betas=None,
|
68 |
+
original_elbo_weight=0.,
|
69 |
+
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
70 |
+
l_simple_weight=1.,
|
71 |
+
conditioning_key=None,
|
72 |
+
parameterization="eps", # all assuming fixed variance schedules
|
73 |
+
scheduler_config=None,
|
74 |
+
use_positional_encodings=False,
|
75 |
+
learn_logvar=False,
|
76 |
+
logvar_init=0.,
|
77 |
+
load_ema=True,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
|
81 |
+
self.parameterization = parameterization
|
82 |
+
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
83 |
+
self.cond_stage_model = None
|
84 |
+
self.clip_denoised = clip_denoised
|
85 |
+
self.log_every_t = log_every_t
|
86 |
+
self.first_stage_key = first_stage_key
|
87 |
+
self.image_size = image_size # try conv?
|
88 |
+
self.channels = channels
|
89 |
+
self.use_positional_encodings = use_positional_encodings
|
90 |
+
self.model = DiffusionWrapper(unet_config, conditioning_key)
|
91 |
+
count_params(self.model, verbose=True)
|
92 |
+
self.use_ema = use_ema
|
93 |
+
|
94 |
+
self.use_scheduler = scheduler_config is not None
|
95 |
+
if self.use_scheduler:
|
96 |
+
self.scheduler_config = scheduler_config
|
97 |
+
|
98 |
+
self.v_posterior = v_posterior
|
99 |
+
self.original_elbo_weight = original_elbo_weight
|
100 |
+
self.l_simple_weight = l_simple_weight
|
101 |
+
|
102 |
+
if monitor is not None:
|
103 |
+
self.monitor = monitor
|
104 |
+
|
105 |
+
if self.use_ema and load_ema:
|
106 |
+
self.model_ema = LitEma(self.model)
|
107 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
108 |
+
|
109 |
+
if ckpt_path is not None:
|
110 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
111 |
+
|
112 |
+
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
113 |
+
if self.use_ema and not load_ema:
|
114 |
+
self.model_ema = LitEma(self.model)
|
115 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
116 |
+
|
117 |
+
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
118 |
+
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
119 |
+
|
120 |
+
self.loss_type = loss_type
|
121 |
+
|
122 |
+
self.learn_logvar = learn_logvar
|
123 |
+
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
|
124 |
+
if self.learn_logvar:
|
125 |
+
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
126 |
+
|
127 |
+
|
128 |
+
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
129 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
130 |
+
if exists(given_betas):
|
131 |
+
betas = given_betas
|
132 |
+
else:
|
133 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
134 |
+
cosine_s=cosine_s)
|
135 |
+
alphas = 1. - betas
|
136 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
137 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
138 |
+
|
139 |
+
timesteps, = betas.shape
|
140 |
+
self.num_timesteps = int(timesteps)
|
141 |
+
self.linear_start = linear_start
|
142 |
+
self.linear_end = linear_end
|
143 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
144 |
+
|
145 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
146 |
+
|
147 |
+
self.register_buffer('betas', to_torch(betas))
|
148 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
149 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
150 |
+
|
151 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
152 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
153 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
154 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
155 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
156 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
157 |
+
|
158 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
159 |
+
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
160 |
+
1. - alphas_cumprod) + self.v_posterior * betas
|
161 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
162 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
163 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
164 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
165 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
166 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
167 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
168 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
169 |
+
|
170 |
+
if self.parameterization == "eps":
|
171 |
+
lvlb_weights = self.betas ** 2 / (
|
172 |
+
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
173 |
+
elif self.parameterization == "x0":
|
174 |
+
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
175 |
+
else:
|
176 |
+
raise NotImplementedError("mu not supported")
|
177 |
+
# TODO how to choose this term
|
178 |
+
lvlb_weights[0] = lvlb_weights[1]
|
179 |
+
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
180 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
181 |
+
|
182 |
+
@contextmanager
|
183 |
+
def ema_scope(self, context=None):
|
184 |
+
if self.use_ema:
|
185 |
+
self.model_ema.store(self.model.parameters())
|
186 |
+
self.model_ema.copy_to(self.model)
|
187 |
+
if context is not None:
|
188 |
+
print(f"{context}: Switched to EMA weights")
|
189 |
+
try:
|
190 |
+
yield None
|
191 |
+
finally:
|
192 |
+
if self.use_ema:
|
193 |
+
self.model_ema.restore(self.model.parameters())
|
194 |
+
if context is not None:
|
195 |
+
print(f"{context}: Restored training weights")
|
196 |
+
|
197 |
+
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
198 |
+
ignore_keys = ignore_keys or []
|
199 |
+
|
200 |
+
sd = torch.load(path, map_location="cpu")
|
201 |
+
if "state_dict" in list(sd.keys()):
|
202 |
+
sd = sd["state_dict"]
|
203 |
+
keys = list(sd.keys())
|
204 |
+
|
205 |
+
# Our model adds additional channels to the first layer to condition on an input image.
|
206 |
+
# For the first layer, copy existing channel weights and initialize new channel weights to zero.
|
207 |
+
input_keys = [
|
208 |
+
"model.diffusion_model.input_blocks.0.0.weight",
|
209 |
+
"model_ema.diffusion_modelinput_blocks00weight",
|
210 |
+
]
|
211 |
+
|
212 |
+
self_sd = self.state_dict()
|
213 |
+
for input_key in input_keys:
|
214 |
+
if input_key not in sd or input_key not in self_sd:
|
215 |
+
continue
|
216 |
+
|
217 |
+
input_weight = self_sd[input_key]
|
218 |
+
|
219 |
+
if input_weight.size() != sd[input_key].size():
|
220 |
+
print(f"Manual init: {input_key}")
|
221 |
+
input_weight.zero_()
|
222 |
+
input_weight[:, :4, :, :].copy_(sd[input_key])
|
223 |
+
ignore_keys.append(input_key)
|
224 |
+
|
225 |
+
for k in keys:
|
226 |
+
for ik in ignore_keys:
|
227 |
+
if k.startswith(ik):
|
228 |
+
print(f"Deleting key {k} from state_dict.")
|
229 |
+
del sd[k]
|
230 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
231 |
+
sd, strict=False)
|
232 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
233 |
+
if missing:
|
234 |
+
print(f"Missing Keys: {missing}")
|
235 |
+
if unexpected:
|
236 |
+
print(f"Unexpected Keys: {unexpected}")
|
237 |
+
|
238 |
+
def q_mean_variance(self, x_start, t):
|
239 |
+
"""
|
240 |
+
Get the distribution q(x_t | x_0).
|
241 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
242 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
243 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
244 |
+
"""
|
245 |
+
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
|
246 |
+
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
247 |
+
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
248 |
+
return mean, variance, log_variance
|
249 |
+
|
250 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
251 |
+
return (
|
252 |
+
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
253 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
254 |
+
)
|
255 |
+
|
256 |
+
def q_posterior(self, x_start, x_t, t):
|
257 |
+
posterior_mean = (
|
258 |
+
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
259 |
+
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
260 |
+
)
|
261 |
+
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
262 |
+
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
263 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
264 |
+
|
265 |
+
def p_mean_variance(self, x, t, clip_denoised: bool):
|
266 |
+
model_out = self.model(x, t)
|
267 |
+
if self.parameterization == "eps":
|
268 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
269 |
+
elif self.parameterization == "x0":
|
270 |
+
x_recon = model_out
|
271 |
+
if clip_denoised:
|
272 |
+
x_recon.clamp_(-1., 1.)
|
273 |
+
|
274 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
275 |
+
return model_mean, posterior_variance, posterior_log_variance
|
276 |
+
|
277 |
+
@torch.no_grad()
|
278 |
+
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
|
279 |
+
b, *_, device = *x.shape, x.device
|
280 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
|
281 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
282 |
+
# no noise when t == 0
|
283 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
284 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
285 |
+
|
286 |
+
@torch.no_grad()
|
287 |
+
def p_sample_loop(self, shape, return_intermediates=False):
|
288 |
+
device = self.betas.device
|
289 |
+
b = shape[0]
|
290 |
+
img = torch.randn(shape, device=device)
|
291 |
+
intermediates = [img]
|
292 |
+
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
|
293 |
+
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
|
294 |
+
clip_denoised=self.clip_denoised)
|
295 |
+
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
|
296 |
+
intermediates.append(img)
|
297 |
+
if return_intermediates:
|
298 |
+
return img, intermediates
|
299 |
+
return img
|
300 |
+
|
301 |
+
@torch.no_grad()
|
302 |
+
def sample(self, batch_size=16, return_intermediates=False):
|
303 |
+
image_size = self.image_size
|
304 |
+
channels = self.channels
|
305 |
+
return self.p_sample_loop((batch_size, channels, image_size, image_size),
|
306 |
+
return_intermediates=return_intermediates)
|
307 |
+
|
308 |
+
def q_sample(self, x_start, t, noise=None):
|
309 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
310 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
311 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
312 |
+
|
313 |
+
def get_loss(self, pred, target, mean=True):
|
314 |
+
if self.loss_type == 'l1':
|
315 |
+
loss = (target - pred).abs()
|
316 |
+
if mean:
|
317 |
+
loss = loss.mean()
|
318 |
+
elif self.loss_type == 'l2':
|
319 |
+
if mean:
|
320 |
+
loss = torch.nn.functional.mse_loss(target, pred)
|
321 |
+
else:
|
322 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
323 |
+
else:
|
324 |
+
raise NotImplementedError("unknown loss type '{loss_type}'")
|
325 |
+
|
326 |
+
return loss
|
327 |
+
|
328 |
+
def p_losses(self, x_start, t, noise=None):
|
329 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
330 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
331 |
+
model_out = self.model(x_noisy, t)
|
332 |
+
|
333 |
+
loss_dict = {}
|
334 |
+
if self.parameterization == "eps":
|
335 |
+
target = noise
|
336 |
+
elif self.parameterization == "x0":
|
337 |
+
target = x_start
|
338 |
+
else:
|
339 |
+
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
340 |
+
|
341 |
+
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
342 |
+
|
343 |
+
log_prefix = 'train' if self.training else 'val'
|
344 |
+
|
345 |
+
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
|
346 |
+
loss_simple = loss.mean() * self.l_simple_weight
|
347 |
+
|
348 |
+
loss_vlb = (self.lvlb_weights[t] * loss).mean()
|
349 |
+
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
|
350 |
+
|
351 |
+
loss = loss_simple + self.original_elbo_weight * loss_vlb
|
352 |
+
|
353 |
+
loss_dict.update({f'{log_prefix}/loss': loss})
|
354 |
+
|
355 |
+
return loss, loss_dict
|
356 |
+
|
357 |
+
def forward(self, x, *args, **kwargs):
|
358 |
+
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
|
359 |
+
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
|
360 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
361 |
+
return self.p_losses(x, t, *args, **kwargs)
|
362 |
+
|
363 |
+
def get_input(self, batch, k):
|
364 |
+
return batch[k]
|
365 |
+
|
366 |
+
def shared_step(self, batch):
|
367 |
+
x = self.get_input(batch, self.first_stage_key)
|
368 |
+
loss, loss_dict = self(x)
|
369 |
+
return loss, loss_dict
|
370 |
+
|
371 |
+
def training_step(self, batch, batch_idx):
|
372 |
+
loss, loss_dict = self.shared_step(batch)
|
373 |
+
|
374 |
+
self.log_dict(loss_dict, prog_bar=True,
|
375 |
+
logger=True, on_step=True, on_epoch=True)
|
376 |
+
|
377 |
+
self.log("global_step", self.global_step,
|
378 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
379 |
+
|
380 |
+
if self.use_scheduler:
|
381 |
+
lr = self.optimizers().param_groups[0]['lr']
|
382 |
+
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
383 |
+
|
384 |
+
return loss
|
385 |
+
|
386 |
+
@torch.no_grad()
|
387 |
+
def validation_step(self, batch, batch_idx):
|
388 |
+
_, loss_dict_no_ema = self.shared_step(batch)
|
389 |
+
with self.ema_scope():
|
390 |
+
_, loss_dict_ema = self.shared_step(batch)
|
391 |
+
loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
|
392 |
+
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
393 |
+
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
394 |
+
|
395 |
+
def on_train_batch_end(self, *args, **kwargs):
|
396 |
+
if self.use_ema:
|
397 |
+
self.model_ema(self.model)
|
398 |
+
|
399 |
+
def _get_rows_from_list(self, samples):
|
400 |
+
n_imgs_per_row = len(samples)
|
401 |
+
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
|
402 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
403 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
404 |
+
return denoise_grid
|
405 |
+
|
406 |
+
@torch.no_grad()
|
407 |
+
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
408 |
+
log = {}
|
409 |
+
x = self.get_input(batch, self.first_stage_key)
|
410 |
+
N = min(x.shape[0], N)
|
411 |
+
n_row = min(x.shape[0], n_row)
|
412 |
+
x = x.to(self.device)[:N]
|
413 |
+
log["inputs"] = x
|
414 |
+
|
415 |
+
# get diffusion row
|
416 |
+
diffusion_row = []
|
417 |
+
x_start = x[:n_row]
|
418 |
+
|
419 |
+
for t in range(self.num_timesteps):
|
420 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
421 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
422 |
+
t = t.to(self.device).long()
|
423 |
+
noise = torch.randn_like(x_start)
|
424 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
425 |
+
diffusion_row.append(x_noisy)
|
426 |
+
|
427 |
+
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
|
428 |
+
|
429 |
+
if sample:
|
430 |
+
# get denoise row
|
431 |
+
with self.ema_scope("Plotting"):
|
432 |
+
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
|
433 |
+
|
434 |
+
log["samples"] = samples
|
435 |
+
log["denoise_row"] = self._get_rows_from_list(denoise_row)
|
436 |
+
|
437 |
+
if return_keys:
|
438 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
439 |
+
return log
|
440 |
+
else:
|
441 |
+
return {key: log[key] for key in return_keys}
|
442 |
+
return log
|
443 |
+
|
444 |
+
def configure_optimizers(self):
|
445 |
+
lr = self.learning_rate
|
446 |
+
params = list(self.model.parameters())
|
447 |
+
if self.learn_logvar:
|
448 |
+
params = params + [self.logvar]
|
449 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
450 |
+
return opt
|
451 |
+
|
452 |
+
|
453 |
+
class LatentDiffusion(DDPM):
|
454 |
+
"""main class"""
|
455 |
+
def __init__(self,
|
456 |
+
first_stage_config,
|
457 |
+
cond_stage_config,
|
458 |
+
num_timesteps_cond=None,
|
459 |
+
cond_stage_key="image",
|
460 |
+
cond_stage_trainable=False,
|
461 |
+
concat_mode=True,
|
462 |
+
cond_stage_forward=None,
|
463 |
+
conditioning_key=None,
|
464 |
+
scale_factor=1.0,
|
465 |
+
scale_by_std=False,
|
466 |
+
load_ema=True,
|
467 |
+
*args, **kwargs):
|
468 |
+
self.num_timesteps_cond = default(num_timesteps_cond, 1)
|
469 |
+
self.scale_by_std = scale_by_std
|
470 |
+
assert self.num_timesteps_cond <= kwargs['timesteps']
|
471 |
+
# for backwards compatibility after implementation of DiffusionWrapper
|
472 |
+
if conditioning_key is None:
|
473 |
+
conditioning_key = 'concat' if concat_mode else 'crossattn'
|
474 |
+
if cond_stage_config == '__is_unconditional__':
|
475 |
+
conditioning_key = None
|
476 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
477 |
+
ignore_keys = kwargs.pop("ignore_keys", [])
|
478 |
+
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
479 |
+
self.concat_mode = concat_mode
|
480 |
+
self.cond_stage_trainable = cond_stage_trainable
|
481 |
+
self.cond_stage_key = cond_stage_key
|
482 |
+
try:
|
483 |
+
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
484 |
+
except Exception:
|
485 |
+
self.num_downs = 0
|
486 |
+
if not scale_by_std:
|
487 |
+
self.scale_factor = scale_factor
|
488 |
+
else:
|
489 |
+
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
490 |
+
self.instantiate_first_stage(first_stage_config)
|
491 |
+
self.instantiate_cond_stage(cond_stage_config)
|
492 |
+
self.cond_stage_forward = cond_stage_forward
|
493 |
+
self.clip_denoised = False
|
494 |
+
self.bbox_tokenizer = None
|
495 |
+
|
496 |
+
self.restarted_from_ckpt = False
|
497 |
+
if ckpt_path is not None:
|
498 |
+
self.init_from_ckpt(ckpt_path, ignore_keys)
|
499 |
+
self.restarted_from_ckpt = True
|
500 |
+
|
501 |
+
if self.use_ema and not load_ema:
|
502 |
+
self.model_ema = LitEma(self.model)
|
503 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
504 |
+
|
505 |
+
def make_cond_schedule(self, ):
|
506 |
+
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
507 |
+
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
508 |
+
self.cond_ids[:self.num_timesteps_cond] = ids
|
509 |
+
|
510 |
+
@rank_zero_only
|
511 |
+
@torch.no_grad()
|
512 |
+
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
513 |
+
# only for very first batch
|
514 |
+
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
|
515 |
+
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
|
516 |
+
# set rescale weight to 1./std of encodings
|
517 |
+
print("### USING STD-RESCALING ###")
|
518 |
+
x = super().get_input(batch, self.first_stage_key)
|
519 |
+
x = x.to(self.device)
|
520 |
+
encoder_posterior = self.encode_first_stage(x)
|
521 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
522 |
+
del self.scale_factor
|
523 |
+
self.register_buffer('scale_factor', 1. / z.flatten().std())
|
524 |
+
print(f"setting self.scale_factor to {self.scale_factor}")
|
525 |
+
print("### USING STD-RESCALING ###")
|
526 |
+
|
527 |
+
def register_schedule(self,
|
528 |
+
given_betas=None, beta_schedule="linear", timesteps=1000,
|
529 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
530 |
+
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
|
531 |
+
|
532 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
533 |
+
if self.shorten_cond_schedule:
|
534 |
+
self.make_cond_schedule()
|
535 |
+
|
536 |
+
def instantiate_first_stage(self, config):
|
537 |
+
model = instantiate_from_config(config)
|
538 |
+
self.first_stage_model = model.eval()
|
539 |
+
self.first_stage_model.train = disabled_train
|
540 |
+
for param in self.first_stage_model.parameters():
|
541 |
+
param.requires_grad = False
|
542 |
+
|
543 |
+
def instantiate_cond_stage(self, config):
|
544 |
+
if not self.cond_stage_trainable:
|
545 |
+
if config == "__is_first_stage__":
|
546 |
+
print("Using first stage also as cond stage.")
|
547 |
+
self.cond_stage_model = self.first_stage_model
|
548 |
+
elif config == "__is_unconditional__":
|
549 |
+
print(f"Training {self.__class__.__name__} as an unconditional model.")
|
550 |
+
self.cond_stage_model = None
|
551 |
+
# self.be_unconditional = True
|
552 |
+
else:
|
553 |
+
model = instantiate_from_config(config)
|
554 |
+
self.cond_stage_model = model.eval()
|
555 |
+
self.cond_stage_model.train = disabled_train
|
556 |
+
for param in self.cond_stage_model.parameters():
|
557 |
+
param.requires_grad = False
|
558 |
+
else:
|
559 |
+
assert config != '__is_first_stage__'
|
560 |
+
assert config != '__is_unconditional__'
|
561 |
+
model = instantiate_from_config(config)
|
562 |
+
self.cond_stage_model = model
|
563 |
+
|
564 |
+
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
|
565 |
+
denoise_row = []
|
566 |
+
for zd in tqdm(samples, desc=desc):
|
567 |
+
denoise_row.append(self.decode_first_stage(zd.to(self.device),
|
568 |
+
force_not_quantize=force_no_decoder_quantization))
|
569 |
+
n_imgs_per_row = len(denoise_row)
|
570 |
+
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
|
571 |
+
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
|
572 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
573 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
574 |
+
return denoise_grid
|
575 |
+
|
576 |
+
def get_first_stage_encoding(self, encoder_posterior):
|
577 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
578 |
+
z = encoder_posterior.sample()
|
579 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
580 |
+
z = encoder_posterior
|
581 |
+
else:
|
582 |
+
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
|
583 |
+
return self.scale_factor * z
|
584 |
+
|
585 |
+
def get_learned_conditioning(self, c):
|
586 |
+
if self.cond_stage_forward is None:
|
587 |
+
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
|
588 |
+
c = self.cond_stage_model.encode(c)
|
589 |
+
if isinstance(c, DiagonalGaussianDistribution):
|
590 |
+
c = c.mode()
|
591 |
+
else:
|
592 |
+
c = self.cond_stage_model(c)
|
593 |
+
else:
|
594 |
+
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
595 |
+
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
596 |
+
return c
|
597 |
+
|
598 |
+
def meshgrid(self, h, w):
|
599 |
+
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
|
600 |
+
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
|
601 |
+
|
602 |
+
arr = torch.cat([y, x], dim=-1)
|
603 |
+
return arr
|
604 |
+
|
605 |
+
def delta_border(self, h, w):
|
606 |
+
"""
|
607 |
+
:param h: height
|
608 |
+
:param w: width
|
609 |
+
:return: normalized distance to image border,
|
610 |
+
wtith min distance = 0 at border and max dist = 0.5 at image center
|
611 |
+
"""
|
612 |
+
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
|
613 |
+
arr = self.meshgrid(h, w) / lower_right_corner
|
614 |
+
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
|
615 |
+
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
|
616 |
+
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
|
617 |
+
return edge_dist
|
618 |
+
|
619 |
+
def get_weighting(self, h, w, Ly, Lx, device):
|
620 |
+
weighting = self.delta_border(h, w)
|
621 |
+
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
|
622 |
+
self.split_input_params["clip_max_weight"], )
|
623 |
+
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
|
624 |
+
|
625 |
+
if self.split_input_params["tie_braker"]:
|
626 |
+
L_weighting = self.delta_border(Ly, Lx)
|
627 |
+
L_weighting = torch.clip(L_weighting,
|
628 |
+
self.split_input_params["clip_min_tie_weight"],
|
629 |
+
self.split_input_params["clip_max_tie_weight"])
|
630 |
+
|
631 |
+
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
|
632 |
+
weighting = weighting * L_weighting
|
633 |
+
return weighting
|
634 |
+
|
635 |
+
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
|
636 |
+
"""
|
637 |
+
:param x: img of size (bs, c, h, w)
|
638 |
+
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
|
639 |
+
"""
|
640 |
+
bs, nc, h, w = x.shape
|
641 |
+
|
642 |
+
# number of crops in image
|
643 |
+
Ly = (h - kernel_size[0]) // stride[0] + 1
|
644 |
+
Lx = (w - kernel_size[1]) // stride[1] + 1
|
645 |
+
|
646 |
+
if uf == 1 and df == 1:
|
647 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
648 |
+
unfold = torch.nn.Unfold(**fold_params)
|
649 |
+
|
650 |
+
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
|
651 |
+
|
652 |
+
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
|
653 |
+
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
|
654 |
+
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
|
655 |
+
|
656 |
+
elif uf > 1 and df == 1:
|
657 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
658 |
+
unfold = torch.nn.Unfold(**fold_params)
|
659 |
+
|
660 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
|
661 |
+
dilation=1, padding=0,
|
662 |
+
stride=(stride[0] * uf, stride[1] * uf))
|
663 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
|
664 |
+
|
665 |
+
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
|
666 |
+
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
|
667 |
+
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
|
668 |
+
|
669 |
+
elif df > 1 and uf == 1:
|
670 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
671 |
+
unfold = torch.nn.Unfold(**fold_params)
|
672 |
+
|
673 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
|
674 |
+
dilation=1, padding=0,
|
675 |
+
stride=(stride[0] // df, stride[1] // df))
|
676 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
|
677 |
+
|
678 |
+
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
|
679 |
+
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
|
680 |
+
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
|
681 |
+
|
682 |
+
else:
|
683 |
+
raise NotImplementedError
|
684 |
+
|
685 |
+
return fold, unfold, normalization, weighting
|
686 |
+
|
687 |
+
@torch.no_grad()
|
688 |
+
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
|
689 |
+
cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
|
690 |
+
x = super().get_input(batch, k)
|
691 |
+
if bs is not None:
|
692 |
+
x = x[:bs]
|
693 |
+
x = x.to(self.device)
|
694 |
+
encoder_posterior = self.encode_first_stage(x)
|
695 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
696 |
+
cond_key = cond_key or self.cond_stage_key
|
697 |
+
xc = super().get_input(batch, cond_key)
|
698 |
+
if bs is not None:
|
699 |
+
xc["c_crossattn"] = xc["c_crossattn"][:bs]
|
700 |
+
xc["c_concat"] = xc["c_concat"][:bs]
|
701 |
+
cond = {}
|
702 |
+
|
703 |
+
# To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
|
704 |
+
random = torch.rand(x.size(0), device=x.device)
|
705 |
+
prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
|
706 |
+
input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
|
707 |
+
|
708 |
+
null_prompt = self.get_learned_conditioning([""])
|
709 |
+
cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
|
710 |
+
cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
|
711 |
+
|
712 |
+
out = [z, cond]
|
713 |
+
if return_first_stage_outputs:
|
714 |
+
xrec = self.decode_first_stage(z)
|
715 |
+
out.extend([x, xrec])
|
716 |
+
if return_original_cond:
|
717 |
+
out.append(xc)
|
718 |
+
return out
|
719 |
+
|
720 |
+
@torch.no_grad()
|
721 |
+
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
722 |
+
if predict_cids:
|
723 |
+
if z.dim() == 4:
|
724 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
725 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
726 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
727 |
+
|
728 |
+
z = 1. / self.scale_factor * z
|
729 |
+
|
730 |
+
if hasattr(self, "split_input_params"):
|
731 |
+
if self.split_input_params["patch_distributed_vq"]:
|
732 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
733 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
734 |
+
uf = self.split_input_params["vqf"]
|
735 |
+
bs, nc, h, w = z.shape
|
736 |
+
if ks[0] > h or ks[1] > w:
|
737 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
738 |
+
print("reducing Kernel")
|
739 |
+
|
740 |
+
if stride[0] > h or stride[1] > w:
|
741 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
742 |
+
print("reducing stride")
|
743 |
+
|
744 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
745 |
+
|
746 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
747 |
+
# 1. Reshape to img shape
|
748 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
749 |
+
|
750 |
+
# 2. apply model loop over last dim
|
751 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
752 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
753 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
754 |
+
for i in range(z.shape[-1])]
|
755 |
+
else:
|
756 |
+
|
757 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
758 |
+
for i in range(z.shape[-1])]
|
759 |
+
|
760 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
761 |
+
o = o * weighting
|
762 |
+
# Reverse 1. reshape to img shape
|
763 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
764 |
+
# stitch crops together
|
765 |
+
decoded = fold(o)
|
766 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
767 |
+
return decoded
|
768 |
+
else:
|
769 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
770 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
771 |
+
else:
|
772 |
+
return self.first_stage_model.decode(z)
|
773 |
+
|
774 |
+
else:
|
775 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
776 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
777 |
+
else:
|
778 |
+
return self.first_stage_model.decode(z)
|
779 |
+
|
780 |
+
# same as above but without decorator
|
781 |
+
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
782 |
+
if predict_cids:
|
783 |
+
if z.dim() == 4:
|
784 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
785 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
786 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
787 |
+
|
788 |
+
z = 1. / self.scale_factor * z
|
789 |
+
|
790 |
+
if hasattr(self, "split_input_params"):
|
791 |
+
if self.split_input_params["patch_distributed_vq"]:
|
792 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
793 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
794 |
+
uf = self.split_input_params["vqf"]
|
795 |
+
bs, nc, h, w = z.shape
|
796 |
+
if ks[0] > h or ks[1] > w:
|
797 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
798 |
+
print("reducing Kernel")
|
799 |
+
|
800 |
+
if stride[0] > h or stride[1] > w:
|
801 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
802 |
+
print("reducing stride")
|
803 |
+
|
804 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
805 |
+
|
806 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
807 |
+
# 1. Reshape to img shape
|
808 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
809 |
+
|
810 |
+
# 2. apply model loop over last dim
|
811 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
812 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
813 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
814 |
+
for i in range(z.shape[-1])]
|
815 |
+
else:
|
816 |
+
|
817 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
818 |
+
for i in range(z.shape[-1])]
|
819 |
+
|
820 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
821 |
+
o = o * weighting
|
822 |
+
# Reverse 1. reshape to img shape
|
823 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
824 |
+
# stitch crops together
|
825 |
+
decoded = fold(o)
|
826 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
827 |
+
return decoded
|
828 |
+
else:
|
829 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
830 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
831 |
+
else:
|
832 |
+
return self.first_stage_model.decode(z)
|
833 |
+
|
834 |
+
else:
|
835 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
836 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
837 |
+
else:
|
838 |
+
return self.first_stage_model.decode(z)
|
839 |
+
|
840 |
+
@torch.no_grad()
|
841 |
+
def encode_first_stage(self, x):
|
842 |
+
if hasattr(self, "split_input_params"):
|
843 |
+
if self.split_input_params["patch_distributed_vq"]:
|
844 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
845 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
846 |
+
df = self.split_input_params["vqf"]
|
847 |
+
self.split_input_params['original_image_size'] = x.shape[-2:]
|
848 |
+
bs, nc, h, w = x.shape
|
849 |
+
if ks[0] > h or ks[1] > w:
|
850 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
851 |
+
print("reducing Kernel")
|
852 |
+
|
853 |
+
if stride[0] > h or stride[1] > w:
|
854 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
855 |
+
print("reducing stride")
|
856 |
+
|
857 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
|
858 |
+
z = unfold(x) # (bn, nc * prod(**ks), L)
|
859 |
+
# Reshape to img shape
|
860 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
861 |
+
|
862 |
+
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
|
863 |
+
for i in range(z.shape[-1])]
|
864 |
+
|
865 |
+
o = torch.stack(output_list, axis=-1)
|
866 |
+
o = o * weighting
|
867 |
+
|
868 |
+
# Reverse reshape to img shape
|
869 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
870 |
+
# stitch crops together
|
871 |
+
decoded = fold(o)
|
872 |
+
decoded = decoded / normalization
|
873 |
+
return decoded
|
874 |
+
|
875 |
+
else:
|
876 |
+
return self.first_stage_model.encode(x)
|
877 |
+
else:
|
878 |
+
return self.first_stage_model.encode(x)
|
879 |
+
|
880 |
+
def shared_step(self, batch, **kwargs):
|
881 |
+
x, c = self.get_input(batch, self.first_stage_key)
|
882 |
+
loss = self(x, c)
|
883 |
+
return loss
|
884 |
+
|
885 |
+
def forward(self, x, c, *args, **kwargs):
|
886 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
887 |
+
if self.model.conditioning_key is not None:
|
888 |
+
assert c is not None
|
889 |
+
if self.cond_stage_trainable:
|
890 |
+
c = self.get_learned_conditioning(c)
|
891 |
+
if self.shorten_cond_schedule: # TODO: drop this option
|
892 |
+
tc = self.cond_ids[t].to(self.device)
|
893 |
+
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
894 |
+
return self.p_losses(x, c, t, *args, **kwargs)
|
895 |
+
|
896 |
+
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
897 |
+
|
898 |
+
if isinstance(cond, dict):
|
899 |
+
# hybrid case, cond is exptected to be a dict
|
900 |
+
pass
|
901 |
+
else:
|
902 |
+
if not isinstance(cond, list):
|
903 |
+
cond = [cond]
|
904 |
+
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
905 |
+
cond = {key: cond}
|
906 |
+
|
907 |
+
if hasattr(self, "split_input_params"):
|
908 |
+
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
909 |
+
assert not return_ids
|
910 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
911 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
912 |
+
|
913 |
+
h, w = x_noisy.shape[-2:]
|
914 |
+
|
915 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
|
916 |
+
|
917 |
+
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
|
918 |
+
# Reshape to img shape
|
919 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
920 |
+
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
921 |
+
|
922 |
+
if self.cond_stage_key in ["image", "LR_image", "segmentation",
|
923 |
+
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
|
924 |
+
c_key = next(iter(cond.keys())) # get key
|
925 |
+
c = next(iter(cond.values())) # get value
|
926 |
+
assert (len(c) == 1) # todo extend to list with more than one elem
|
927 |
+
c = c[0] # get element
|
928 |
+
|
929 |
+
c = unfold(c)
|
930 |
+
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
931 |
+
|
932 |
+
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
933 |
+
|
934 |
+
elif self.cond_stage_key == 'coordinates_bbox':
|
935 |
+
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
936 |
+
|
937 |
+
# assuming padding of unfold is always 0 and its dilation is always 1
|
938 |
+
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
939 |
+
full_img_h, full_img_w = self.split_input_params['original_image_size']
|
940 |
+
# as we are operating on latents, we need the factor from the original image size to the
|
941 |
+
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
|
942 |
+
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
943 |
+
rescale_latent = 2 ** (num_downs)
|
944 |
+
|
945 |
+
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
946 |
+
# need to rescale the tl patch coordinates to be in between (0,1)
|
947 |
+
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
948 |
+
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
949 |
+
for patch_nr in range(z.shape[-1])]
|
950 |
+
|
951 |
+
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
|
952 |
+
patch_limits = [(x_tl, y_tl,
|
953 |
+
rescale_latent * ks[0] / full_img_w,
|
954 |
+
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
|
955 |
+
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
|
956 |
+
|
957 |
+
# tokenize crop coordinates for the bounding boxes of the respective patches
|
958 |
+
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
|
959 |
+
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
|
960 |
+
print(patch_limits_tknzd[0].shape)
|
961 |
+
# cut tknzd crop position from conditioning
|
962 |
+
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
|
963 |
+
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
|
964 |
+
print(cut_cond.shape)
|
965 |
+
|
966 |
+
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
|
967 |
+
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
|
968 |
+
print(adapted_cond.shape)
|
969 |
+
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
970 |
+
print(adapted_cond.shape)
|
971 |
+
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
|
972 |
+
print(adapted_cond.shape)
|
973 |
+
|
974 |
+
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
|
975 |
+
|
976 |
+
else:
|
977 |
+
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
|
978 |
+
|
979 |
+
# apply model by loop over crops
|
980 |
+
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
|
981 |
+
assert not isinstance(output_list[0],
|
982 |
+
tuple) # todo cant deal with multiple model outputs check this never happens
|
983 |
+
|
984 |
+
o = torch.stack(output_list, axis=-1)
|
985 |
+
o = o * weighting
|
986 |
+
# Reverse reshape to img shape
|
987 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
988 |
+
# stitch crops together
|
989 |
+
x_recon = fold(o) / normalization
|
990 |
+
|
991 |
+
else:
|
992 |
+
x_recon = self.model(x_noisy, t, **cond)
|
993 |
+
|
994 |
+
if isinstance(x_recon, tuple) and not return_ids:
|
995 |
+
return x_recon[0]
|
996 |
+
else:
|
997 |
+
return x_recon
|
998 |
+
|
999 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
1000 |
+
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
1001 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
1002 |
+
|
1003 |
+
def _prior_bpd(self, x_start):
|
1004 |
+
"""
|
1005 |
+
Get the prior KL term for the variational lower-bound, measured in
|
1006 |
+
bits-per-dim.
|
1007 |
+
This term can't be optimized, as it only depends on the encoder.
|
1008 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
1009 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
1010 |
+
"""
|
1011 |
+
batch_size = x_start.shape[0]
|
1012 |
+
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
1013 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
1014 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
1015 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
1016 |
+
|
1017 |
+
def p_losses(self, x_start, cond, t, noise=None):
|
1018 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
1019 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
1020 |
+
model_output = self.apply_model(x_noisy, t, cond)
|
1021 |
+
|
1022 |
+
loss_dict = {}
|
1023 |
+
prefix = 'train' if self.training else 'val'
|
1024 |
+
|
1025 |
+
if self.parameterization == "x0":
|
1026 |
+
target = x_start
|
1027 |
+
elif self.parameterization == "eps":
|
1028 |
+
target = noise
|
1029 |
+
else:
|
1030 |
+
raise NotImplementedError()
|
1031 |
+
|
1032 |
+
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
1033 |
+
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
|
1034 |
+
|
1035 |
+
logvar_t = self.logvar[t].to(self.device)
|
1036 |
+
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
1037 |
+
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
1038 |
+
if self.learn_logvar:
|
1039 |
+
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
|
1040 |
+
loss_dict.update({'logvar': self.logvar.data.mean()})
|
1041 |
+
|
1042 |
+
loss = self.l_simple_weight * loss.mean()
|
1043 |
+
|
1044 |
+
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
1045 |
+
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
1046 |
+
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
|
1047 |
+
loss += (self.original_elbo_weight * loss_vlb)
|
1048 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
1049 |
+
|
1050 |
+
return loss, loss_dict
|
1051 |
+
|
1052 |
+
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
|
1053 |
+
return_x0=False, score_corrector=None, corrector_kwargs=None):
|
1054 |
+
t_in = t
|
1055 |
+
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
1056 |
+
|
1057 |
+
if score_corrector is not None:
|
1058 |
+
assert self.parameterization == "eps"
|
1059 |
+
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
|
1060 |
+
|
1061 |
+
if return_codebook_ids:
|
1062 |
+
model_out, logits = model_out
|
1063 |
+
|
1064 |
+
if self.parameterization == "eps":
|
1065 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
1066 |
+
elif self.parameterization == "x0":
|
1067 |
+
x_recon = model_out
|
1068 |
+
else:
|
1069 |
+
raise NotImplementedError()
|
1070 |
+
|
1071 |
+
if clip_denoised:
|
1072 |
+
x_recon.clamp_(-1., 1.)
|
1073 |
+
if quantize_denoised:
|
1074 |
+
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
|
1075 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
1076 |
+
if return_codebook_ids:
|
1077 |
+
return model_mean, posterior_variance, posterior_log_variance, logits
|
1078 |
+
elif return_x0:
|
1079 |
+
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
1080 |
+
else:
|
1081 |
+
return model_mean, posterior_variance, posterior_log_variance
|
1082 |
+
|
1083 |
+
@torch.no_grad()
|
1084 |
+
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
|
1085 |
+
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
|
1086 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
1087 |
+
b, *_, device = *x.shape, x.device
|
1088 |
+
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
|
1089 |
+
return_codebook_ids=return_codebook_ids,
|
1090 |
+
quantize_denoised=quantize_denoised,
|
1091 |
+
return_x0=return_x0,
|
1092 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1093 |
+
if return_codebook_ids:
|
1094 |
+
raise DeprecationWarning("Support dropped.")
|
1095 |
+
model_mean, _, model_log_variance, logits = outputs
|
1096 |
+
elif return_x0:
|
1097 |
+
model_mean, _, model_log_variance, x0 = outputs
|
1098 |
+
else:
|
1099 |
+
model_mean, _, model_log_variance = outputs
|
1100 |
+
|
1101 |
+
noise = noise_like(x.shape, device, repeat_noise) * temperature
|
1102 |
+
if noise_dropout > 0.:
|
1103 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
1104 |
+
# no noise when t == 0
|
1105 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
1106 |
+
|
1107 |
+
if return_codebook_ids:
|
1108 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
|
1109 |
+
if return_x0:
|
1110 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
|
1111 |
+
else:
|
1112 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
1113 |
+
|
1114 |
+
@torch.no_grad()
|
1115 |
+
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
|
1116 |
+
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
|
1117 |
+
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
|
1118 |
+
log_every_t=None):
|
1119 |
+
if not log_every_t:
|
1120 |
+
log_every_t = self.log_every_t
|
1121 |
+
timesteps = self.num_timesteps
|
1122 |
+
if batch_size is not None:
|
1123 |
+
b = batch_size if batch_size is not None else shape[0]
|
1124 |
+
shape = [batch_size] + list(shape)
|
1125 |
+
else:
|
1126 |
+
b = batch_size = shape[0]
|
1127 |
+
if x_T is None:
|
1128 |
+
img = torch.randn(shape, device=self.device)
|
1129 |
+
else:
|
1130 |
+
img = x_T
|
1131 |
+
intermediates = []
|
1132 |
+
if cond is not None:
|
1133 |
+
if isinstance(cond, dict):
|
1134 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1135 |
+
[x[:batch_size] for x in cond[key]] for key in cond}
|
1136 |
+
else:
|
1137 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1138 |
+
|
1139 |
+
if start_T is not None:
|
1140 |
+
timesteps = min(timesteps, start_T)
|
1141 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
|
1142 |
+
total=timesteps) if verbose else reversed(
|
1143 |
+
range(0, timesteps))
|
1144 |
+
if type(temperature) == float:
|
1145 |
+
temperature = [temperature] * timesteps
|
1146 |
+
|
1147 |
+
for i in iterator:
|
1148 |
+
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
|
1149 |
+
if self.shorten_cond_schedule:
|
1150 |
+
assert self.model.conditioning_key != 'hybrid'
|
1151 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1152 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1153 |
+
|
1154 |
+
img, x0_partial = self.p_sample(img, cond, ts,
|
1155 |
+
clip_denoised=self.clip_denoised,
|
1156 |
+
quantize_denoised=quantize_denoised, return_x0=True,
|
1157 |
+
temperature=temperature[i], noise_dropout=noise_dropout,
|
1158 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1159 |
+
if mask is not None:
|
1160 |
+
assert x0 is not None
|
1161 |
+
img_orig = self.q_sample(x0, ts)
|
1162 |
+
img = img_orig * mask + (1. - mask) * img
|
1163 |
+
|
1164 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1165 |
+
intermediates.append(x0_partial)
|
1166 |
+
if callback:
|
1167 |
+
callback(i)
|
1168 |
+
if img_callback:
|
1169 |
+
img_callback(img, i)
|
1170 |
+
return img, intermediates
|
1171 |
+
|
1172 |
+
@torch.no_grad()
|
1173 |
+
def p_sample_loop(self, cond, shape, return_intermediates=False,
|
1174 |
+
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
|
1175 |
+
mask=None, x0=None, img_callback=None, start_T=None,
|
1176 |
+
log_every_t=None):
|
1177 |
+
|
1178 |
+
if not log_every_t:
|
1179 |
+
log_every_t = self.log_every_t
|
1180 |
+
device = self.betas.device
|
1181 |
+
b = shape[0]
|
1182 |
+
if x_T is None:
|
1183 |
+
img = torch.randn(shape, device=device)
|
1184 |
+
else:
|
1185 |
+
img = x_T
|
1186 |
+
|
1187 |
+
intermediates = [img]
|
1188 |
+
if timesteps is None:
|
1189 |
+
timesteps = self.num_timesteps
|
1190 |
+
|
1191 |
+
if start_T is not None:
|
1192 |
+
timesteps = min(timesteps, start_T)
|
1193 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
|
1194 |
+
range(0, timesteps))
|
1195 |
+
|
1196 |
+
if mask is not None:
|
1197 |
+
assert x0 is not None
|
1198 |
+
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
|
1199 |
+
|
1200 |
+
for i in iterator:
|
1201 |
+
ts = torch.full((b,), i, device=device, dtype=torch.long)
|
1202 |
+
if self.shorten_cond_schedule:
|
1203 |
+
assert self.model.conditioning_key != 'hybrid'
|
1204 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1205 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1206 |
+
|
1207 |
+
img = self.p_sample(img, cond, ts,
|
1208 |
+
clip_denoised=self.clip_denoised,
|
1209 |
+
quantize_denoised=quantize_denoised)
|
1210 |
+
if mask is not None:
|
1211 |
+
img_orig = self.q_sample(x0, ts)
|
1212 |
+
img = img_orig * mask + (1. - mask) * img
|
1213 |
+
|
1214 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1215 |
+
intermediates.append(img)
|
1216 |
+
if callback:
|
1217 |
+
callback(i)
|
1218 |
+
if img_callback:
|
1219 |
+
img_callback(img, i)
|
1220 |
+
|
1221 |
+
if return_intermediates:
|
1222 |
+
return img, intermediates
|
1223 |
+
return img
|
1224 |
+
|
1225 |
+
@torch.no_grad()
|
1226 |
+
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
1227 |
+
verbose=True, timesteps=None, quantize_denoised=False,
|
1228 |
+
mask=None, x0=None, shape=None,**kwargs):
|
1229 |
+
if shape is None:
|
1230 |
+
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
1231 |
+
if cond is not None:
|
1232 |
+
if isinstance(cond, dict):
|
1233 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1234 |
+
[x[:batch_size] for x in cond[key]] for key in cond}
|
1235 |
+
else:
|
1236 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1237 |
+
return self.p_sample_loop(cond,
|
1238 |
+
shape,
|
1239 |
+
return_intermediates=return_intermediates, x_T=x_T,
|
1240 |
+
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
|
1241 |
+
mask=mask, x0=x0)
|
1242 |
+
|
1243 |
+
@torch.no_grad()
|
1244 |
+
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
1245 |
+
|
1246 |
+
if ddim:
|
1247 |
+
ddim_sampler = DDIMSampler(self)
|
1248 |
+
shape = (self.channels, self.image_size, self.image_size)
|
1249 |
+
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
1250 |
+
shape,cond,verbose=False,**kwargs)
|
1251 |
+
|
1252 |
+
else:
|
1253 |
+
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
1254 |
+
return_intermediates=True,**kwargs)
|
1255 |
+
|
1256 |
+
return samples, intermediates
|
1257 |
+
|
1258 |
+
|
1259 |
+
@torch.no_grad()
|
1260 |
+
def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
1261 |
+
quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
|
1262 |
+
plot_diffusion_rows=False, **kwargs):
|
1263 |
+
|
1264 |
+
use_ddim = False
|
1265 |
+
|
1266 |
+
log = {}
|
1267 |
+
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
1268 |
+
return_first_stage_outputs=True,
|
1269 |
+
force_c_encode=True,
|
1270 |
+
return_original_cond=True,
|
1271 |
+
bs=N, uncond=0)
|
1272 |
+
N = min(x.shape[0], N)
|
1273 |
+
n_row = min(x.shape[0], n_row)
|
1274 |
+
log["inputs"] = x
|
1275 |
+
log["reals"] = xc["c_concat"]
|
1276 |
+
log["reconstruction"] = xrec
|
1277 |
+
if self.model.conditioning_key is not None:
|
1278 |
+
if hasattr(self.cond_stage_model, "decode"):
|
1279 |
+
xc = self.cond_stage_model.decode(c)
|
1280 |
+
log["conditioning"] = xc
|
1281 |
+
elif self.cond_stage_key in ["caption"]:
|
1282 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
|
1283 |
+
log["conditioning"] = xc
|
1284 |
+
elif self.cond_stage_key == 'class_label':
|
1285 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
1286 |
+
log['conditioning'] = xc
|
1287 |
+
elif isimage(xc):
|
1288 |
+
log["conditioning"] = xc
|
1289 |
+
if ismap(xc):
|
1290 |
+
log["original_conditioning"] = self.to_rgb(xc)
|
1291 |
+
|
1292 |
+
if plot_diffusion_rows:
|
1293 |
+
# get diffusion row
|
1294 |
+
diffusion_row = []
|
1295 |
+
z_start = z[:n_row]
|
1296 |
+
for t in range(self.num_timesteps):
|
1297 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
1298 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
1299 |
+
t = t.to(self.device).long()
|
1300 |
+
noise = torch.randn_like(z_start)
|
1301 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
1302 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
1303 |
+
|
1304 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
1305 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
1306 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
1307 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
1308 |
+
log["diffusion_row"] = diffusion_grid
|
1309 |
+
|
1310 |
+
if sample:
|
1311 |
+
# get denoise row
|
1312 |
+
with self.ema_scope("Plotting"):
|
1313 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1314 |
+
ddim_steps=ddim_steps,eta=ddim_eta)
|
1315 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
1316 |
+
x_samples = self.decode_first_stage(samples)
|
1317 |
+
log["samples"] = x_samples
|
1318 |
+
if plot_denoise_rows:
|
1319 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
1320 |
+
log["denoise_row"] = denoise_grid
|
1321 |
+
|
1322 |
+
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
1323 |
+
self.first_stage_model, IdentityFirstStage):
|
1324 |
+
# also display when quantizing x0 while sampling
|
1325 |
+
with self.ema_scope("Plotting Quantized Denoised"):
|
1326 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1327 |
+
ddim_steps=ddim_steps,eta=ddim_eta,
|
1328 |
+
quantize_denoised=True)
|
1329 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
1330 |
+
# quantize_denoised=True)
|
1331 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1332 |
+
log["samples_x0_quantized"] = x_samples
|
1333 |
+
|
1334 |
+
if inpaint:
|
1335 |
+
# make a simple center square
|
1336 |
+
h, w = z.shape[2], z.shape[3]
|
1337 |
+
mask = torch.ones(N, h, w).to(self.device)
|
1338 |
+
# zeros will be filled in
|
1339 |
+
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
1340 |
+
mask = mask[:, None, ...]
|
1341 |
+
with self.ema_scope("Plotting Inpaint"):
|
1342 |
+
|
1343 |
+
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
1344 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1345 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1346 |
+
log["samples_inpainting"] = x_samples
|
1347 |
+
log["mask"] = mask
|
1348 |
+
|
1349 |
+
# outpaint
|
1350 |
+
with self.ema_scope("Plotting Outpaint"):
|
1351 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
1352 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1353 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1354 |
+
log["samples_outpainting"] = x_samples
|
1355 |
+
|
1356 |
+
if plot_progressive_rows:
|
1357 |
+
with self.ema_scope("Plotting Progressives"):
|
1358 |
+
img, progressives = self.progressive_denoising(c,
|
1359 |
+
shape=(self.channels, self.image_size, self.image_size),
|
1360 |
+
batch_size=N)
|
1361 |
+
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
|
1362 |
+
log["progressive_row"] = prog_row
|
1363 |
+
|
1364 |
+
if return_keys:
|
1365 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
1366 |
+
return log
|
1367 |
+
else:
|
1368 |
+
return {key: log[key] for key in return_keys}
|
1369 |
+
return log
|
1370 |
+
|
1371 |
+
def configure_optimizers(self):
|
1372 |
+
lr = self.learning_rate
|
1373 |
+
params = list(self.model.parameters())
|
1374 |
+
if self.cond_stage_trainable:
|
1375 |
+
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
|
1376 |
+
params = params + list(self.cond_stage_model.parameters())
|
1377 |
+
if self.learn_logvar:
|
1378 |
+
print('Diffusion model optimizing logvar')
|
1379 |
+
params.append(self.logvar)
|
1380 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
1381 |
+
if self.use_scheduler:
|
1382 |
+
assert 'target' in self.scheduler_config
|
1383 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
1384 |
+
|
1385 |
+
print("Setting up LambdaLR scheduler...")
|
1386 |
+
scheduler = [
|
1387 |
+
{
|
1388 |
+
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
|
1389 |
+
'interval': 'step',
|
1390 |
+
'frequency': 1
|
1391 |
+
}]
|
1392 |
+
return [opt], scheduler
|
1393 |
+
return opt
|
1394 |
+
|
1395 |
+
@torch.no_grad()
|
1396 |
+
def to_rgb(self, x):
|
1397 |
+
x = x.float()
|
1398 |
+
if not hasattr(self, "colorize"):
|
1399 |
+
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
|
1400 |
+
x = nn.functional.conv2d(x, weight=self.colorize)
|
1401 |
+
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
1402 |
+
return x
|
1403 |
+
|
1404 |
+
|
1405 |
+
class DiffusionWrapper(pl.LightningModule):
|
1406 |
+
def __init__(self, diff_model_config, conditioning_key):
|
1407 |
+
super().__init__()
|
1408 |
+
self.diffusion_model = instantiate_from_config(diff_model_config)
|
1409 |
+
self.conditioning_key = conditioning_key
|
1410 |
+
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
|
1411 |
+
|
1412 |
+
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
|
1413 |
+
if self.conditioning_key is None:
|
1414 |
+
out = self.diffusion_model(x, t)
|
1415 |
+
elif self.conditioning_key == 'concat':
|
1416 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1417 |
+
out = self.diffusion_model(xc, t)
|
1418 |
+
elif self.conditioning_key == 'crossattn':
|
1419 |
+
cc = torch.cat(c_crossattn, 1)
|
1420 |
+
out = self.diffusion_model(x, t, context=cc)
|
1421 |
+
elif self.conditioning_key == 'hybrid':
|
1422 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1423 |
+
cc = torch.cat(c_crossattn, 1)
|
1424 |
+
out = self.diffusion_model(xc, t, context=cc)
|
1425 |
+
elif self.conditioning_key == 'adm':
|
1426 |
+
cc = c_crossattn[0]
|
1427 |
+
out = self.diffusion_model(x, t, y=cc)
|
1428 |
+
else:
|
1429 |
+
raise NotImplementedError()
|
1430 |
+
|
1431 |
+
return out
|
1432 |
+
|
1433 |
+
|
1434 |
+
class Layout2ImgDiffusion(LatentDiffusion):
|
1435 |
+
# TODO: move all layout-specific hacks to this class
|
1436 |
+
def __init__(self, cond_stage_key, *args, **kwargs):
|
1437 |
+
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
1438 |
+
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
1439 |
+
|
1440 |
+
def log_images(self, batch, N=8, *args, **kwargs):
|
1441 |
+
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
1442 |
+
|
1443 |
+
key = 'train' if self.training else 'validation'
|
1444 |
+
dset = self.trainer.datamodule.datasets[key]
|
1445 |
+
mapper = dset.conditional_builders[self.cond_stage_key]
|
1446 |
+
|
1447 |
+
bbox_imgs = []
|
1448 |
+
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
|
1449 |
+
for tknzd_bbox in batch[self.cond_stage_key][:N]:
|
1450 |
+
bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
|
1451 |
+
bbox_imgs.append(bboximg)
|
1452 |
+
|
1453 |
+
cond_img = torch.stack(bbox_imgs, dim=0)
|
1454 |
+
logs['bbox_image'] = cond_img
|
1455 |
+
return logs
|
modules/models/diffusion/uni_pc/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .sampler import UniPCSampler # noqa: F401
|
modules/models/diffusion/uni_pc/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (204 Bytes). View file
|
|
modules/models/diffusion/uni_pc/__pycache__/sampler.cpython-310.pyc
ADDED
Binary file (3.21 kB). View file
|
|
modules/models/diffusion/uni_pc/__pycache__/uni_pc.cpython-310.pyc
ADDED
Binary file (27.6 kB). View file
|
|
modules/models/diffusion/uni_pc/sampler.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""SAMPLING ONLY."""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
6 |
+
from modules import shared, devices
|
7 |
+
|
8 |
+
|
9 |
+
class UniPCSampler(object):
|
10 |
+
def __init__(self, model, **kwargs):
|
11 |
+
super().__init__()
|
12 |
+
self.model = model
|
13 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
14 |
+
self.before_sample = None
|
15 |
+
self.after_sample = None
|
16 |
+
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
17 |
+
|
18 |
+
def register_buffer(self, name, attr):
|
19 |
+
if type(attr) == torch.Tensor:
|
20 |
+
if attr.device != devices.device:
|
21 |
+
attr = attr.to(devices.device)
|
22 |
+
setattr(self, name, attr)
|
23 |
+
|
24 |
+
def set_hooks(self, before_sample, after_sample, after_update):
|
25 |
+
self.before_sample = before_sample
|
26 |
+
self.after_sample = after_sample
|
27 |
+
self.after_update = after_update
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def sample(self,
|
31 |
+
S,
|
32 |
+
batch_size,
|
33 |
+
shape,
|
34 |
+
conditioning=None,
|
35 |
+
callback=None,
|
36 |
+
normals_sequence=None,
|
37 |
+
img_callback=None,
|
38 |
+
quantize_x0=False,
|
39 |
+
eta=0.,
|
40 |
+
mask=None,
|
41 |
+
x0=None,
|
42 |
+
temperature=1.,
|
43 |
+
noise_dropout=0.,
|
44 |
+
score_corrector=None,
|
45 |
+
corrector_kwargs=None,
|
46 |
+
verbose=True,
|
47 |
+
x_T=None,
|
48 |
+
log_every_t=100,
|
49 |
+
unconditional_guidance_scale=1.,
|
50 |
+
unconditional_conditioning=None,
|
51 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
52 |
+
**kwargs
|
53 |
+
):
|
54 |
+
if conditioning is not None:
|
55 |
+
if isinstance(conditioning, dict):
|
56 |
+
ctmp = conditioning[list(conditioning.keys())[0]]
|
57 |
+
while isinstance(ctmp, list):
|
58 |
+
ctmp = ctmp[0]
|
59 |
+
cbs = ctmp.shape[0]
|
60 |
+
if cbs != batch_size:
|
61 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
62 |
+
|
63 |
+
elif isinstance(conditioning, list):
|
64 |
+
for ctmp in conditioning:
|
65 |
+
if ctmp.shape[0] != batch_size:
|
66 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
67 |
+
|
68 |
+
else:
|
69 |
+
if conditioning.shape[0] != batch_size:
|
70 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
71 |
+
|
72 |
+
# sampling
|
73 |
+
C, H, W = shape
|
74 |
+
size = (batch_size, C, H, W)
|
75 |
+
# print(f'Data shape for UniPC sampling is {size}')
|
76 |
+
|
77 |
+
device = self.model.betas.device
|
78 |
+
if x_T is None:
|
79 |
+
img = torch.randn(size, device=device)
|
80 |
+
else:
|
81 |
+
img = x_T
|
82 |
+
|
83 |
+
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
84 |
+
|
85 |
+
# SD 1.X is "noise", SD 2.X is "v"
|
86 |
+
model_type = "v" if self.model.parameterization == "v" else "noise"
|
87 |
+
|
88 |
+
model_fn = model_wrapper(
|
89 |
+
lambda x, t, c: self.model.apply_model(x, t, c),
|
90 |
+
ns,
|
91 |
+
model_type=model_type,
|
92 |
+
guidance_type="classifier-free",
|
93 |
+
#condition=conditioning,
|
94 |
+
#unconditional_condition=unconditional_conditioning,
|
95 |
+
guidance_scale=unconditional_guidance_scale,
|
96 |
+
)
|
97 |
+
|
98 |
+
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
|
99 |
+
x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
100 |
+
|
101 |
+
return x.to(device), None
|
modules/models/diffusion/uni_pc/uni_pc.py
ADDED
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
class NoiseScheduleVP:
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
schedule='discrete',
|
10 |
+
betas=None,
|
11 |
+
alphas_cumprod=None,
|
12 |
+
continuous_beta_0=0.1,
|
13 |
+
continuous_beta_1=20.,
|
14 |
+
):
|
15 |
+
"""Create a wrapper class for the forward SDE (VP type).
|
16 |
+
|
17 |
+
***
|
18 |
+
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
|
19 |
+
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
|
20 |
+
***
|
21 |
+
|
22 |
+
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
23 |
+
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
24 |
+
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
25 |
+
|
26 |
+
log_alpha_t = self.marginal_log_mean_coeff(t)
|
27 |
+
sigma_t = self.marginal_std(t)
|
28 |
+
lambda_t = self.marginal_lambda(t)
|
29 |
+
|
30 |
+
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
31 |
+
|
32 |
+
t = self.inverse_lambda(lambda_t)
|
33 |
+
|
34 |
+
===============================================================
|
35 |
+
|
36 |
+
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
37 |
+
|
38 |
+
1. For discrete-time DPMs:
|
39 |
+
|
40 |
+
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
41 |
+
t_i = (i + 1) / N
|
42 |
+
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
43 |
+
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
47 |
+
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
48 |
+
|
49 |
+
Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
50 |
+
|
51 |
+
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
52 |
+
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
|
53 |
+
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
54 |
+
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
55 |
+
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
56 |
+
and
|
57 |
+
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
58 |
+
|
59 |
+
|
60 |
+
2. For continuous-time DPMs:
|
61 |
+
|
62 |
+
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
|
63 |
+
schedule are the default settings in DDPM and improved-DDPM:
|
64 |
+
|
65 |
+
Args:
|
66 |
+
beta_min: A `float` number. The smallest beta for the linear schedule.
|
67 |
+
beta_max: A `float` number. The largest beta for the linear schedule.
|
68 |
+
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
|
69 |
+
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
|
70 |
+
T: A `float` number. The ending time of the forward process.
|
71 |
+
|
72 |
+
===============================================================
|
73 |
+
|
74 |
+
Args:
|
75 |
+
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
76 |
+
'linear' or 'cosine' for continuous-time DPMs.
|
77 |
+
Returns:
|
78 |
+
A wrapper object of the forward SDE (VP type).
|
79 |
+
|
80 |
+
===============================================================
|
81 |
+
|
82 |
+
Example:
|
83 |
+
|
84 |
+
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
85 |
+
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
86 |
+
|
87 |
+
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
88 |
+
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
89 |
+
|
90 |
+
# For continuous-time DPMs (VPSDE), linear schedule:
|
91 |
+
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
92 |
+
|
93 |
+
"""
|
94 |
+
|
95 |
+
if schedule not in ['discrete', 'linear', 'cosine']:
|
96 |
+
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
|
97 |
+
|
98 |
+
self.schedule = schedule
|
99 |
+
if schedule == 'discrete':
|
100 |
+
if betas is not None:
|
101 |
+
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
102 |
+
else:
|
103 |
+
assert alphas_cumprod is not None
|
104 |
+
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
105 |
+
self.total_N = len(log_alphas)
|
106 |
+
self.T = 1.
|
107 |
+
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
|
108 |
+
self.log_alpha_array = log_alphas.reshape((1, -1,))
|
109 |
+
else:
|
110 |
+
self.total_N = 1000
|
111 |
+
self.beta_0 = continuous_beta_0
|
112 |
+
self.beta_1 = continuous_beta_1
|
113 |
+
self.cosine_s = 0.008
|
114 |
+
self.cosine_beta_max = 999.
|
115 |
+
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
116 |
+
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
117 |
+
self.schedule = schedule
|
118 |
+
if schedule == 'cosine':
|
119 |
+
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
120 |
+
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
121 |
+
self.T = 0.9946
|
122 |
+
else:
|
123 |
+
self.T = 1.
|
124 |
+
|
125 |
+
def marginal_log_mean_coeff(self, t):
|
126 |
+
"""
|
127 |
+
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
128 |
+
"""
|
129 |
+
if self.schedule == 'discrete':
|
130 |
+
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
|
131 |
+
elif self.schedule == 'linear':
|
132 |
+
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
133 |
+
elif self.schedule == 'cosine':
|
134 |
+
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
135 |
+
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
136 |
+
return log_alpha_t
|
137 |
+
|
138 |
+
def marginal_alpha(self, t):
|
139 |
+
"""
|
140 |
+
Compute alpha_t of a given continuous-time label t in [0, T].
|
141 |
+
"""
|
142 |
+
return torch.exp(self.marginal_log_mean_coeff(t))
|
143 |
+
|
144 |
+
def marginal_std(self, t):
|
145 |
+
"""
|
146 |
+
Compute sigma_t of a given continuous-time label t in [0, T].
|
147 |
+
"""
|
148 |
+
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
149 |
+
|
150 |
+
def marginal_lambda(self, t):
|
151 |
+
"""
|
152 |
+
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
153 |
+
"""
|
154 |
+
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
155 |
+
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
156 |
+
return log_mean_coeff - log_std
|
157 |
+
|
158 |
+
def inverse_lambda(self, lamb):
|
159 |
+
"""
|
160 |
+
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
161 |
+
"""
|
162 |
+
if self.schedule == 'linear':
|
163 |
+
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
164 |
+
Delta = self.beta_0**2 + tmp
|
165 |
+
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
166 |
+
elif self.schedule == 'discrete':
|
167 |
+
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
168 |
+
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
|
169 |
+
return t.reshape((-1,))
|
170 |
+
else:
|
171 |
+
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
172 |
+
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
173 |
+
t = t_fn(log_alpha)
|
174 |
+
return t
|
175 |
+
|
176 |
+
|
177 |
+
def model_wrapper(
|
178 |
+
model,
|
179 |
+
noise_schedule,
|
180 |
+
model_type="noise",
|
181 |
+
model_kwargs=None,
|
182 |
+
guidance_type="uncond",
|
183 |
+
#condition=None,
|
184 |
+
#unconditional_condition=None,
|
185 |
+
guidance_scale=1.,
|
186 |
+
classifier_fn=None,
|
187 |
+
classifier_kwargs=None,
|
188 |
+
):
|
189 |
+
"""Create a wrapper function for the noise prediction model.
|
190 |
+
|
191 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
192 |
+
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
|
193 |
+
|
194 |
+
We support four types of the diffusion model by setting `model_type`:
|
195 |
+
|
196 |
+
1. "noise": noise prediction model. (Trained by predicting noise).
|
197 |
+
|
198 |
+
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
199 |
+
|
200 |
+
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
201 |
+
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
202 |
+
|
203 |
+
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
204 |
+
arXiv preprint arXiv:2202.00512 (2022).
|
205 |
+
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
206 |
+
arXiv preprint arXiv:2210.02303 (2022).
|
207 |
+
|
208 |
+
4. "score": marginal score function. (Trained by denoising score matching).
|
209 |
+
Note that the score function and the noise prediction model follows a simple relationship:
|
210 |
+
```
|
211 |
+
noise(x_t, t) = -sigma_t * score(x_t, t)
|
212 |
+
```
|
213 |
+
|
214 |
+
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
215 |
+
1. "uncond": unconditional sampling by DPMs.
|
216 |
+
The input `model` has the following format:
|
217 |
+
``
|
218 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
219 |
+
``
|
220 |
+
|
221 |
+
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
222 |
+
The input `model` has the following format:
|
223 |
+
``
|
224 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
225 |
+
``
|
226 |
+
|
227 |
+
The input `classifier_fn` has the following format:
|
228 |
+
``
|
229 |
+
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
230 |
+
``
|
231 |
+
|
232 |
+
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
233 |
+
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
234 |
+
|
235 |
+
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
236 |
+
The input `model` has the following format:
|
237 |
+
``
|
238 |
+
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
239 |
+
``
|
240 |
+
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
241 |
+
|
242 |
+
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
243 |
+
arXiv preprint arXiv:2207.12598 (2022).
|
244 |
+
|
245 |
+
|
246 |
+
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
247 |
+
or continuous-time labels (i.e. epsilon to T).
|
248 |
+
|
249 |
+
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
250 |
+
``
|
251 |
+
def model_fn(x, t_continuous) -> noise:
|
252 |
+
t_input = get_model_input_time(t_continuous)
|
253 |
+
return noise_pred(model, x, t_input, **model_kwargs)
|
254 |
+
``
|
255 |
+
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
256 |
+
|
257 |
+
===============================================================
|
258 |
+
|
259 |
+
Args:
|
260 |
+
model: A diffusion model with the corresponding format described above.
|
261 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
262 |
+
model_type: A `str`. The parameterization type of the diffusion model.
|
263 |
+
"noise" or "x_start" or "v" or "score".
|
264 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
265 |
+
guidance_type: A `str`. The type of the guidance for sampling.
|
266 |
+
"uncond" or "classifier" or "classifier-free".
|
267 |
+
condition: A pytorch tensor. The condition for the guided sampling.
|
268 |
+
Only used for "classifier" or "classifier-free" guidance type.
|
269 |
+
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
270 |
+
Only used for "classifier-free" guidance type.
|
271 |
+
guidance_scale: A `float`. The scale for the guided sampling.
|
272 |
+
classifier_fn: A classifier function. Only used for the classifier guidance.
|
273 |
+
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
274 |
+
Returns:
|
275 |
+
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
276 |
+
"""
|
277 |
+
|
278 |
+
model_kwargs = model_kwargs or {}
|
279 |
+
classifier_kwargs = classifier_kwargs or {}
|
280 |
+
|
281 |
+
def get_model_input_time(t_continuous):
|
282 |
+
"""
|
283 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
284 |
+
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
285 |
+
For continuous-time DPMs, we just use `t_continuous`.
|
286 |
+
"""
|
287 |
+
if noise_schedule.schedule == 'discrete':
|
288 |
+
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
|
289 |
+
else:
|
290 |
+
return t_continuous
|
291 |
+
|
292 |
+
def noise_pred_fn(x, t_continuous, cond=None):
|
293 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
294 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
295 |
+
t_input = get_model_input_time(t_continuous)
|
296 |
+
if cond is None:
|
297 |
+
output = model(x, t_input, None, **model_kwargs)
|
298 |
+
else:
|
299 |
+
output = model(x, t_input, cond, **model_kwargs)
|
300 |
+
if model_type == "noise":
|
301 |
+
return output
|
302 |
+
elif model_type == "x_start":
|
303 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
304 |
+
dims = x.dim()
|
305 |
+
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
306 |
+
elif model_type == "v":
|
307 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
308 |
+
dims = x.dim()
|
309 |
+
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
310 |
+
elif model_type == "score":
|
311 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
312 |
+
dims = x.dim()
|
313 |
+
return -expand_dims(sigma_t, dims) * output
|
314 |
+
|
315 |
+
def cond_grad_fn(x, t_input, condition):
|
316 |
+
"""
|
317 |
+
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
318 |
+
"""
|
319 |
+
with torch.enable_grad():
|
320 |
+
x_in = x.detach().requires_grad_(True)
|
321 |
+
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
322 |
+
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
323 |
+
|
324 |
+
def model_fn(x, t_continuous, condition, unconditional_condition):
|
325 |
+
"""
|
326 |
+
The noise predicition model function that is used for DPM-Solver.
|
327 |
+
"""
|
328 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
329 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
330 |
+
if guidance_type == "uncond":
|
331 |
+
return noise_pred_fn(x, t_continuous)
|
332 |
+
elif guidance_type == "classifier":
|
333 |
+
assert classifier_fn is not None
|
334 |
+
t_input = get_model_input_time(t_continuous)
|
335 |
+
cond_grad = cond_grad_fn(x, t_input, condition)
|
336 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
337 |
+
noise = noise_pred_fn(x, t_continuous)
|
338 |
+
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
|
339 |
+
elif guidance_type == "classifier-free":
|
340 |
+
if guidance_scale == 1. or unconditional_condition is None:
|
341 |
+
return noise_pred_fn(x, t_continuous, cond=condition)
|
342 |
+
else:
|
343 |
+
x_in = torch.cat([x] * 2)
|
344 |
+
t_in = torch.cat([t_continuous] * 2)
|
345 |
+
if isinstance(condition, dict):
|
346 |
+
assert isinstance(unconditional_condition, dict)
|
347 |
+
c_in = {}
|
348 |
+
for k in condition:
|
349 |
+
if isinstance(condition[k], list):
|
350 |
+
c_in[k] = [torch.cat([
|
351 |
+
unconditional_condition[k][i],
|
352 |
+
condition[k][i]]) for i in range(len(condition[k]))]
|
353 |
+
else:
|
354 |
+
c_in[k] = torch.cat([
|
355 |
+
unconditional_condition[k],
|
356 |
+
condition[k]])
|
357 |
+
elif isinstance(condition, list):
|
358 |
+
c_in = []
|
359 |
+
assert isinstance(unconditional_condition, list)
|
360 |
+
for i in range(len(condition)):
|
361 |
+
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
362 |
+
else:
|
363 |
+
c_in = torch.cat([unconditional_condition, condition])
|
364 |
+
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
365 |
+
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
366 |
+
|
367 |
+
assert model_type in ["noise", "x_start", "v"]
|
368 |
+
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
369 |
+
return model_fn
|
370 |
+
|
371 |
+
|
372 |
+
class UniPC:
|
373 |
+
def __init__(
|
374 |
+
self,
|
375 |
+
model_fn,
|
376 |
+
noise_schedule,
|
377 |
+
predict_x0=True,
|
378 |
+
thresholding=False,
|
379 |
+
max_val=1.,
|
380 |
+
variant='bh1',
|
381 |
+
condition=None,
|
382 |
+
unconditional_condition=None,
|
383 |
+
before_sample=None,
|
384 |
+
after_sample=None,
|
385 |
+
after_update=None
|
386 |
+
):
|
387 |
+
"""Construct a UniPC.
|
388 |
+
|
389 |
+
We support both data_prediction and noise_prediction.
|
390 |
+
"""
|
391 |
+
self.model_fn_ = model_fn
|
392 |
+
self.noise_schedule = noise_schedule
|
393 |
+
self.variant = variant
|
394 |
+
self.predict_x0 = predict_x0
|
395 |
+
self.thresholding = thresholding
|
396 |
+
self.max_val = max_val
|
397 |
+
self.condition = condition
|
398 |
+
self.unconditional_condition = unconditional_condition
|
399 |
+
self.before_sample = before_sample
|
400 |
+
self.after_sample = after_sample
|
401 |
+
self.after_update = after_update
|
402 |
+
|
403 |
+
def dynamic_thresholding_fn(self, x0, t=None):
|
404 |
+
"""
|
405 |
+
The dynamic thresholding method.
|
406 |
+
"""
|
407 |
+
dims = x0.dim()
|
408 |
+
p = self.dynamic_thresholding_ratio
|
409 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
410 |
+
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
411 |
+
x0 = torch.clamp(x0, -s, s) / s
|
412 |
+
return x0
|
413 |
+
|
414 |
+
def model(self, x, t):
|
415 |
+
cond = self.condition
|
416 |
+
uncond = self.unconditional_condition
|
417 |
+
if self.before_sample is not None:
|
418 |
+
x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
|
419 |
+
res = self.model_fn_(x, t, cond, uncond)
|
420 |
+
if self.after_sample is not None:
|
421 |
+
x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
|
422 |
+
|
423 |
+
if isinstance(res, tuple):
|
424 |
+
# (None, pred_x0)
|
425 |
+
res = res[1]
|
426 |
+
|
427 |
+
return res
|
428 |
+
|
429 |
+
def noise_prediction_fn(self, x, t):
|
430 |
+
"""
|
431 |
+
Return the noise prediction model.
|
432 |
+
"""
|
433 |
+
return self.model(x, t)
|
434 |
+
|
435 |
+
def data_prediction_fn(self, x, t):
|
436 |
+
"""
|
437 |
+
Return the data prediction model (with thresholding).
|
438 |
+
"""
|
439 |
+
noise = self.noise_prediction_fn(x, t)
|
440 |
+
dims = x.dim()
|
441 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
442 |
+
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
443 |
+
if self.thresholding:
|
444 |
+
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
445 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
446 |
+
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
447 |
+
x0 = torch.clamp(x0, -s, s) / s
|
448 |
+
return x0
|
449 |
+
|
450 |
+
def model_fn(self, x, t):
|
451 |
+
"""
|
452 |
+
Convert the model to the noise prediction model or the data prediction model.
|
453 |
+
"""
|
454 |
+
if self.predict_x0:
|
455 |
+
return self.data_prediction_fn(x, t)
|
456 |
+
else:
|
457 |
+
return self.noise_prediction_fn(x, t)
|
458 |
+
|
459 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
460 |
+
"""Compute the intermediate time steps for sampling.
|
461 |
+
"""
|
462 |
+
if skip_type == 'logSNR':
|
463 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
464 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
465 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
466 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
467 |
+
elif skip_type == 'time_uniform':
|
468 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
469 |
+
elif skip_type == 'time_quadratic':
|
470 |
+
t_order = 2
|
471 |
+
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
472 |
+
return t
|
473 |
+
else:
|
474 |
+
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
475 |
+
|
476 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
477 |
+
"""
|
478 |
+
Get the order of each step for sampling by the singlestep DPM-Solver.
|
479 |
+
"""
|
480 |
+
if order == 3:
|
481 |
+
K = steps // 3 + 1
|
482 |
+
if steps % 3 == 0:
|
483 |
+
orders = [3,] * (K - 2) + [2, 1]
|
484 |
+
elif steps % 3 == 1:
|
485 |
+
orders = [3,] * (K - 1) + [1]
|
486 |
+
else:
|
487 |
+
orders = [3,] * (K - 1) + [2]
|
488 |
+
elif order == 2:
|
489 |
+
if steps % 2 == 0:
|
490 |
+
K = steps // 2
|
491 |
+
orders = [2,] * K
|
492 |
+
else:
|
493 |
+
K = steps // 2 + 1
|
494 |
+
orders = [2,] * (K - 1) + [1]
|
495 |
+
elif order == 1:
|
496 |
+
K = steps
|
497 |
+
orders = [1,] * steps
|
498 |
+
else:
|
499 |
+
raise ValueError("'order' must be '1' or '2' or '3'.")
|
500 |
+
if skip_type == 'logSNR':
|
501 |
+
# To reproduce the results in DPM-Solver paper
|
502 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
503 |
+
else:
|
504 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
|
505 |
+
return timesteps_outer, orders
|
506 |
+
|
507 |
+
def denoise_to_zero_fn(self, x, s):
|
508 |
+
"""
|
509 |
+
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
510 |
+
"""
|
511 |
+
return self.data_prediction_fn(x, s)
|
512 |
+
|
513 |
+
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
|
514 |
+
if len(t.shape) == 0:
|
515 |
+
t = t.view(-1)
|
516 |
+
if 'bh' in self.variant:
|
517 |
+
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
518 |
+
else:
|
519 |
+
assert self.variant == 'vary_coeff'
|
520 |
+
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
521 |
+
|
522 |
+
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
523 |
+
#print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
524 |
+
ns = self.noise_schedule
|
525 |
+
assert order <= len(model_prev_list)
|
526 |
+
|
527 |
+
# first compute rks
|
528 |
+
t_prev_0 = t_prev_list[-1]
|
529 |
+
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
530 |
+
lambda_t = ns.marginal_lambda(t)
|
531 |
+
model_prev_0 = model_prev_list[-1]
|
532 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
533 |
+
log_alpha_t = ns.marginal_log_mean_coeff(t)
|
534 |
+
alpha_t = torch.exp(log_alpha_t)
|
535 |
+
|
536 |
+
h = lambda_t - lambda_prev_0
|
537 |
+
|
538 |
+
rks = []
|
539 |
+
D1s = []
|
540 |
+
for i in range(1, order):
|
541 |
+
t_prev_i = t_prev_list[-(i + 1)]
|
542 |
+
model_prev_i = model_prev_list[-(i + 1)]
|
543 |
+
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
544 |
+
rk = (lambda_prev_i - lambda_prev_0) / h
|
545 |
+
rks.append(rk)
|
546 |
+
D1s.append((model_prev_i - model_prev_0) / rk)
|
547 |
+
|
548 |
+
rks.append(1.)
|
549 |
+
rks = torch.tensor(rks, device=x.device)
|
550 |
+
|
551 |
+
K = len(rks)
|
552 |
+
# build C matrix
|
553 |
+
C = []
|
554 |
+
|
555 |
+
col = torch.ones_like(rks)
|
556 |
+
for k in range(1, K + 1):
|
557 |
+
C.append(col)
|
558 |
+
col = col * rks / (k + 1)
|
559 |
+
C = torch.stack(C, dim=1)
|
560 |
+
|
561 |
+
if len(D1s) > 0:
|
562 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
563 |
+
C_inv_p = torch.linalg.inv(C[:-1, :-1])
|
564 |
+
A_p = C_inv_p
|
565 |
+
|
566 |
+
if use_corrector:
|
567 |
+
#print('using corrector')
|
568 |
+
C_inv = torch.linalg.inv(C)
|
569 |
+
A_c = C_inv
|
570 |
+
|
571 |
+
hh = -h if self.predict_x0 else h
|
572 |
+
h_phi_1 = torch.expm1(hh)
|
573 |
+
h_phi_ks = []
|
574 |
+
factorial_k = 1
|
575 |
+
h_phi_k = h_phi_1
|
576 |
+
for k in range(1, K + 2):
|
577 |
+
h_phi_ks.append(h_phi_k)
|
578 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_k
|
579 |
+
factorial_k *= (k + 1)
|
580 |
+
|
581 |
+
model_t = None
|
582 |
+
if self.predict_x0:
|
583 |
+
x_t_ = (
|
584 |
+
sigma_t / sigma_prev_0 * x
|
585 |
+
- alpha_t * h_phi_1 * model_prev_0
|
586 |
+
)
|
587 |
+
# now predictor
|
588 |
+
x_t = x_t_
|
589 |
+
if len(D1s) > 0:
|
590 |
+
# compute the residuals for predictor
|
591 |
+
for k in range(K - 1):
|
592 |
+
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
593 |
+
# now corrector
|
594 |
+
if use_corrector:
|
595 |
+
model_t = self.model_fn(x_t, t)
|
596 |
+
D1_t = (model_t - model_prev_0)
|
597 |
+
x_t = x_t_
|
598 |
+
k = 0
|
599 |
+
for k in range(K - 1):
|
600 |
+
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
601 |
+
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
602 |
+
else:
|
603 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
604 |
+
x_t_ = (
|
605 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
606 |
+
- (sigma_t * h_phi_1) * model_prev_0
|
607 |
+
)
|
608 |
+
# now predictor
|
609 |
+
x_t = x_t_
|
610 |
+
if len(D1s) > 0:
|
611 |
+
# compute the residuals for predictor
|
612 |
+
for k in range(K - 1):
|
613 |
+
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
614 |
+
# now corrector
|
615 |
+
if use_corrector:
|
616 |
+
model_t = self.model_fn(x_t, t)
|
617 |
+
D1_t = (model_t - model_prev_0)
|
618 |
+
x_t = x_t_
|
619 |
+
k = 0
|
620 |
+
for k in range(K - 1):
|
621 |
+
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
622 |
+
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
623 |
+
return x_t, model_t
|
624 |
+
|
625 |
+
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
626 |
+
#print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
627 |
+
ns = self.noise_schedule
|
628 |
+
assert order <= len(model_prev_list)
|
629 |
+
dims = x.dim()
|
630 |
+
|
631 |
+
# first compute rks
|
632 |
+
t_prev_0 = t_prev_list[-1]
|
633 |
+
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
634 |
+
lambda_t = ns.marginal_lambda(t)
|
635 |
+
model_prev_0 = model_prev_list[-1]
|
636 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
637 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
638 |
+
alpha_t = torch.exp(log_alpha_t)
|
639 |
+
|
640 |
+
h = lambda_t - lambda_prev_0
|
641 |
+
|
642 |
+
rks = []
|
643 |
+
D1s = []
|
644 |
+
for i in range(1, order):
|
645 |
+
t_prev_i = t_prev_list[-(i + 1)]
|
646 |
+
model_prev_i = model_prev_list[-(i + 1)]
|
647 |
+
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
648 |
+
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
|
649 |
+
rks.append(rk)
|
650 |
+
D1s.append((model_prev_i - model_prev_0) / rk)
|
651 |
+
|
652 |
+
rks.append(1.)
|
653 |
+
rks = torch.tensor(rks, device=x.device)
|
654 |
+
|
655 |
+
R = []
|
656 |
+
b = []
|
657 |
+
|
658 |
+
hh = -h[0] if self.predict_x0 else h[0]
|
659 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
660 |
+
h_phi_k = h_phi_1 / hh - 1
|
661 |
+
|
662 |
+
factorial_i = 1
|
663 |
+
|
664 |
+
if self.variant == 'bh1':
|
665 |
+
B_h = hh
|
666 |
+
elif self.variant == 'bh2':
|
667 |
+
B_h = torch.expm1(hh)
|
668 |
+
else:
|
669 |
+
raise NotImplementedError()
|
670 |
+
|
671 |
+
for i in range(1, order + 1):
|
672 |
+
R.append(torch.pow(rks, i - 1))
|
673 |
+
b.append(h_phi_k * factorial_i / B_h)
|
674 |
+
factorial_i *= (i + 1)
|
675 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
676 |
+
|
677 |
+
R = torch.stack(R)
|
678 |
+
b = torch.tensor(b, device=x.device)
|
679 |
+
|
680 |
+
# now predictor
|
681 |
+
use_predictor = len(D1s) > 0 and x_t is None
|
682 |
+
if len(D1s) > 0:
|
683 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
684 |
+
if x_t is None:
|
685 |
+
# for order 2, we use a simplified version
|
686 |
+
if order == 2:
|
687 |
+
rhos_p = torch.tensor([0.5], device=b.device)
|
688 |
+
else:
|
689 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
690 |
+
else:
|
691 |
+
D1s = None
|
692 |
+
|
693 |
+
if use_corrector:
|
694 |
+
#print('using corrector')
|
695 |
+
# for order 1, we use a simplified version
|
696 |
+
if order == 1:
|
697 |
+
rhos_c = torch.tensor([0.5], device=b.device)
|
698 |
+
else:
|
699 |
+
rhos_c = torch.linalg.solve(R, b)
|
700 |
+
|
701 |
+
model_t = None
|
702 |
+
if self.predict_x0:
|
703 |
+
x_t_ = (
|
704 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
705 |
+
- expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
|
706 |
+
)
|
707 |
+
|
708 |
+
if x_t is None:
|
709 |
+
if use_predictor:
|
710 |
+
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
711 |
+
else:
|
712 |
+
pred_res = 0
|
713 |
+
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
714 |
+
|
715 |
+
if use_corrector:
|
716 |
+
model_t = self.model_fn(x_t, t)
|
717 |
+
if D1s is not None:
|
718 |
+
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
719 |
+
else:
|
720 |
+
corr_res = 0
|
721 |
+
D1_t = (model_t - model_prev_0)
|
722 |
+
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
723 |
+
else:
|
724 |
+
x_t_ = (
|
725 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
726 |
+
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
727 |
+
)
|
728 |
+
if x_t is None:
|
729 |
+
if use_predictor:
|
730 |
+
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
731 |
+
else:
|
732 |
+
pred_res = 0
|
733 |
+
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
|
734 |
+
|
735 |
+
if use_corrector:
|
736 |
+
model_t = self.model_fn(x_t, t)
|
737 |
+
if D1s is not None:
|
738 |
+
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
739 |
+
else:
|
740 |
+
corr_res = 0
|
741 |
+
D1_t = (model_t - model_prev_0)
|
742 |
+
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
743 |
+
return x_t, model_t
|
744 |
+
|
745 |
+
|
746 |
+
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
747 |
+
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
748 |
+
atol=0.0078, rtol=0.05, corrector=False,
|
749 |
+
):
|
750 |
+
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
751 |
+
t_T = self.noise_schedule.T if t_start is None else t_start
|
752 |
+
device = x.device
|
753 |
+
if method == 'multistep':
|
754 |
+
assert steps >= order, "UniPC order must be < sampling steps"
|
755 |
+
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
756 |
+
#print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
|
757 |
+
assert timesteps.shape[0] - 1 == steps
|
758 |
+
with torch.no_grad():
|
759 |
+
vec_t = timesteps[0].expand((x.shape[0]))
|
760 |
+
model_prev_list = [self.model_fn(x, vec_t)]
|
761 |
+
t_prev_list = [vec_t]
|
762 |
+
with tqdm.tqdm(total=steps) as pbar:
|
763 |
+
# Init the first `order` values by lower order multistep DPM-Solver.
|
764 |
+
for init_order in range(1, order):
|
765 |
+
vec_t = timesteps[init_order].expand(x.shape[0])
|
766 |
+
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
767 |
+
if model_x is None:
|
768 |
+
model_x = self.model_fn(x, vec_t)
|
769 |
+
if self.after_update is not None:
|
770 |
+
self.after_update(x, model_x)
|
771 |
+
model_prev_list.append(model_x)
|
772 |
+
t_prev_list.append(vec_t)
|
773 |
+
pbar.update()
|
774 |
+
|
775 |
+
for step in range(order, steps + 1):
|
776 |
+
vec_t = timesteps[step].expand(x.shape[0])
|
777 |
+
if lower_order_final:
|
778 |
+
step_order = min(order, steps + 1 - step)
|
779 |
+
else:
|
780 |
+
step_order = order
|
781 |
+
#print('this step order:', step_order)
|
782 |
+
if step == steps:
|
783 |
+
#print('do not run corrector at the last step')
|
784 |
+
use_corrector = False
|
785 |
+
else:
|
786 |
+
use_corrector = True
|
787 |
+
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
788 |
+
if self.after_update is not None:
|
789 |
+
self.after_update(x, model_x)
|
790 |
+
for i in range(order - 1):
|
791 |
+
t_prev_list[i] = t_prev_list[i + 1]
|
792 |
+
model_prev_list[i] = model_prev_list[i + 1]
|
793 |
+
t_prev_list[-1] = vec_t
|
794 |
+
# We do not need to evaluate the final model value.
|
795 |
+
if step < steps:
|
796 |
+
if model_x is None:
|
797 |
+
model_x = self.model_fn(x, vec_t)
|
798 |
+
model_prev_list[-1] = model_x
|
799 |
+
pbar.update()
|
800 |
+
else:
|
801 |
+
raise NotImplementedError()
|
802 |
+
if denoise_to_zero:
|
803 |
+
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
804 |
+
return x
|
805 |
+
|
806 |
+
|
807 |
+
#############################################################
|
808 |
+
# other utility functions
|
809 |
+
#############################################################
|
810 |
+
|
811 |
+
def interpolate_fn(x, xp, yp):
|
812 |
+
"""
|
813 |
+
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
814 |
+
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
815 |
+
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
|
816 |
+
|
817 |
+
Args:
|
818 |
+
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
819 |
+
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
820 |
+
yp: PyTorch tensor with shape [C, K].
|
821 |
+
Returns:
|
822 |
+
The function values f(x), with shape [N, C].
|
823 |
+
"""
|
824 |
+
N, K = x.shape[0], xp.shape[1]
|
825 |
+
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
826 |
+
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
827 |
+
x_idx = torch.argmin(x_indices, dim=2)
|
828 |
+
cand_start_idx = x_idx - 1
|
829 |
+
start_idx = torch.where(
|
830 |
+
torch.eq(x_idx, 0),
|
831 |
+
torch.tensor(1, device=x.device),
|
832 |
+
torch.where(
|
833 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
834 |
+
),
|
835 |
+
)
|
836 |
+
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
837 |
+
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
838 |
+
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
839 |
+
start_idx2 = torch.where(
|
840 |
+
torch.eq(x_idx, 0),
|
841 |
+
torch.tensor(0, device=x.device),
|
842 |
+
torch.where(
|
843 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
844 |
+
),
|
845 |
+
)
|
846 |
+
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
847 |
+
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
848 |
+
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
849 |
+
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
850 |
+
return cand
|
851 |
+
|
852 |
+
|
853 |
+
def expand_dims(v, dims):
|
854 |
+
"""
|
855 |
+
Expand the tensor `v` to the dim `dims`.
|
856 |
+
|
857 |
+
Args:
|
858 |
+
`v`: a PyTorch tensor with shape [N].
|
859 |
+
`dim`: a `int`.
|
860 |
+
Returns:
|
861 |
+
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
862 |
+
"""
|
863 |
+
return v[(...,) + (None,)*(dims - 1)]
|
modules/ngrok.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ngrok
|
2 |
+
|
3 |
+
# Connect to ngrok for ingress
|
4 |
+
def connect(token, port, options):
|
5 |
+
account = None
|
6 |
+
if token is None:
|
7 |
+
token = 'None'
|
8 |
+
else:
|
9 |
+
if ':' in token:
|
10 |
+
# token = authtoken:username:password
|
11 |
+
token, username, password = token.split(':', 2)
|
12 |
+
account = f"{username}:{password}"
|
13 |
+
|
14 |
+
# For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
|
15 |
+
if not options.get('authtoken_from_env'):
|
16 |
+
options['authtoken'] = token
|
17 |
+
if account:
|
18 |
+
options['basic_auth'] = account
|
19 |
+
if not options.get('session_metadata'):
|
20 |
+
options['session_metadata'] = 'stable-diffusion-webui'
|
21 |
+
|
22 |
+
|
23 |
+
try:
|
24 |
+
public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
|
25 |
+
except Exception as e:
|
26 |
+
print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
|
27 |
+
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
|
28 |
+
else:
|
29 |
+
print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
|
30 |
+
'You can use this link after the launch is complete.')
|
modules/paths.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
|
4 |
+
|
5 |
+
import modules.safe # noqa: F401
|
6 |
+
|
7 |
+
|
8 |
+
def mute_sdxl_imports():
|
9 |
+
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
|
10 |
+
|
11 |
+
class Dummy:
|
12 |
+
pass
|
13 |
+
|
14 |
+
module = Dummy()
|
15 |
+
module.LPIPS = None
|
16 |
+
sys.modules['taming.modules.losses.lpips'] = module
|
17 |
+
|
18 |
+
module = Dummy()
|
19 |
+
module.StableDataModuleFromConfig = None
|
20 |
+
sys.modules['sgm.data'] = module
|
21 |
+
|
22 |
+
|
23 |
+
# data_path = cmd_opts_pre.data
|
24 |
+
sys.path.insert(0, script_path)
|
25 |
+
|
26 |
+
# search for directory of stable diffusion in following places
|
27 |
+
sd_path = None
|
28 |
+
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
|
29 |
+
for possible_sd_path in possible_sd_paths:
|
30 |
+
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
31 |
+
sd_path = os.path.abspath(possible_sd_path)
|
32 |
+
break
|
33 |
+
|
34 |
+
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
35 |
+
|
36 |
+
mute_sdxl_imports()
|
37 |
+
|
38 |
+
path_dirs = [
|
39 |
+
(sd_path, 'ldm', 'Stable Diffusion', []),
|
40 |
+
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
41 |
+
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
42 |
+
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
43 |
+
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
44 |
+
]
|
45 |
+
|
46 |
+
paths = {}
|
47 |
+
|
48 |
+
for d, must_exist, what, options in path_dirs:
|
49 |
+
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
|
50 |
+
if not os.path.exists(must_exist_path):
|
51 |
+
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
|
52 |
+
else:
|
53 |
+
d = os.path.abspath(d)
|
54 |
+
if "atstart" in options:
|
55 |
+
sys.path.insert(0, d)
|
56 |
+
elif "sgm" in options:
|
57 |
+
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
|
58 |
+
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
|
59 |
+
|
60 |
+
sys.path.insert(0, d)
|
61 |
+
import sgm # noqa: F401
|
62 |
+
sys.path.pop(0)
|
63 |
+
else:
|
64 |
+
sys.path.append(d)
|
65 |
+
paths[what] = d
|
modules/paths_internal.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import shlex
|
7 |
+
|
8 |
+
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
9 |
+
sys.argv += shlex.split(commandline_args)
|
10 |
+
|
11 |
+
modules_path = os.path.dirname(os.path.realpath(__file__))
|
12 |
+
script_path = os.path.dirname(modules_path)
|
13 |
+
|
14 |
+
sd_configs_path = os.path.join(script_path, "configs")
|
15 |
+
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
16 |
+
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
17 |
+
default_sd_model_file = sd_model_file
|
18 |
+
|
19 |
+
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
20 |
+
parser_pre = argparse.ArgumentParser(add_help=False)
|
21 |
+
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
|
22 |
+
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
23 |
+
|
24 |
+
data_path = cmd_opts_pre.data_dir
|
25 |
+
|
26 |
+
models_path = os.path.join(data_path, "models")
|
27 |
+
extensions_dir = os.path.join(data_path, "extensions")
|
28 |
+
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
29 |
+
config_states_dir = os.path.join(script_path, "config_states")
|
30 |
+
|
31 |
+
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
modules/postprocessing.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
|
6 |
+
from modules.shared import opts
|
7 |
+
|
8 |
+
|
9 |
+
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
10 |
+
devices.torch_gc()
|
11 |
+
|
12 |
+
shared.state.begin(job="extras")
|
13 |
+
|
14 |
+
image_data = []
|
15 |
+
image_names = []
|
16 |
+
outputs = []
|
17 |
+
|
18 |
+
if extras_mode == 1:
|
19 |
+
for img in image_folder:
|
20 |
+
if isinstance(img, Image.Image):
|
21 |
+
image = img
|
22 |
+
fn = ''
|
23 |
+
else:
|
24 |
+
image = Image.open(os.path.abspath(img.name))
|
25 |
+
fn = os.path.splitext(img.orig_name)[0]
|
26 |
+
image_data.append(image)
|
27 |
+
image_names.append(fn)
|
28 |
+
elif extras_mode == 2:
|
29 |
+
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
30 |
+
assert input_dir, 'input directory not selected'
|
31 |
+
|
32 |
+
image_list = shared.listfiles(input_dir)
|
33 |
+
for filename in image_list:
|
34 |
+
try:
|
35 |
+
image = Image.open(filename)
|
36 |
+
except Exception:
|
37 |
+
continue
|
38 |
+
image_data.append(image)
|
39 |
+
image_names.append(filename)
|
40 |
+
else:
|
41 |
+
assert image, 'image not selected'
|
42 |
+
|
43 |
+
image_data.append(image)
|
44 |
+
image_names.append(None)
|
45 |
+
|
46 |
+
if extras_mode == 2 and output_dir != '':
|
47 |
+
outpath = output_dir
|
48 |
+
else:
|
49 |
+
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
50 |
+
|
51 |
+
infotext = ''
|
52 |
+
|
53 |
+
for image, name in zip(image_data, image_names):
|
54 |
+
shared.state.textinfo = name
|
55 |
+
|
56 |
+
parameters, existing_pnginfo = images.read_info_from_image(image)
|
57 |
+
if parameters:
|
58 |
+
existing_pnginfo["parameters"] = parameters
|
59 |
+
|
60 |
+
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
61 |
+
|
62 |
+
scripts.scripts_postproc.run(pp, args)
|
63 |
+
|
64 |
+
if opts.use_original_name_batch and name is not None:
|
65 |
+
basename = os.path.splitext(os.path.basename(name))[0]
|
66 |
+
else:
|
67 |
+
basename = ''
|
68 |
+
|
69 |
+
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
70 |
+
|
71 |
+
if opts.enable_pnginfo:
|
72 |
+
pp.image.info = existing_pnginfo
|
73 |
+
pp.image.info["postprocessing"] = infotext
|
74 |
+
|
75 |
+
if save_output:
|
76 |
+
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
77 |
+
|
78 |
+
if extras_mode != 2 or show_extras_results:
|
79 |
+
outputs.append(pp.image)
|
80 |
+
|
81 |
+
devices.torch_gc()
|
82 |
+
|
83 |
+
return outputs, ui_common.plaintext_to_html(infotext), ''
|
84 |
+
|
85 |
+
|
86 |
+
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
87 |
+
"""old handler for API"""
|
88 |
+
|
89 |
+
args = scripts.scripts_postproc.create_args_for_run({
|
90 |
+
"Upscale": {
|
91 |
+
"upscale_mode": resize_mode,
|
92 |
+
"upscale_by": upscaling_resize,
|
93 |
+
"upscale_to_width": upscaling_resize_w,
|
94 |
+
"upscale_to_height": upscaling_resize_h,
|
95 |
+
"upscale_crop": upscaling_crop,
|
96 |
+
"upscaler_1_name": extras_upscaler_1,
|
97 |
+
"upscaler_2_name": extras_upscaler_2,
|
98 |
+
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
99 |
+
},
|
100 |
+
"GFPGAN": {
|
101 |
+
"gfpgan_visibility": gfpgan_visibility,
|
102 |
+
},
|
103 |
+
"CodeFormer": {
|
104 |
+
"codeformer_visibility": codeformer_visibility,
|
105 |
+
"codeformer_weight": codeformer_weight,
|
106 |
+
},
|
107 |
+
})
|
108 |
+
|
109 |
+
return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
|
modules/processing.py
ADDED
@@ -0,0 +1,1405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import hashlib
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image, ImageOps
|
11 |
+
import random
|
12 |
+
import cv2
|
13 |
+
from skimage import exposure
|
14 |
+
from typing import Any, Dict, List
|
15 |
+
|
16 |
+
import modules.sd_hijack
|
17 |
+
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
|
18 |
+
from modules.sd_hijack import model_hijack
|
19 |
+
from modules.shared import opts, cmd_opts, state
|
20 |
+
import modules.shared as shared
|
21 |
+
import modules.paths as paths
|
22 |
+
import modules.face_restoration
|
23 |
+
import modules.images as images
|
24 |
+
import modules.styles
|
25 |
+
import modules.sd_models as sd_models
|
26 |
+
import modules.sd_vae as sd_vae
|
27 |
+
from ldm.data.util import AddMiDaS
|
28 |
+
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
29 |
+
|
30 |
+
from einops import repeat, rearrange
|
31 |
+
from blendmodes.blend import blendLayers, BlendType
|
32 |
+
|
33 |
+
|
34 |
+
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
35 |
+
opt_C = 4
|
36 |
+
opt_f = 8
|
37 |
+
|
38 |
+
|
39 |
+
def setup_color_correction(image):
|
40 |
+
logging.info("Calibrating color correction.")
|
41 |
+
correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
|
42 |
+
return correction_target
|
43 |
+
|
44 |
+
|
45 |
+
def apply_color_correction(correction, original_image):
|
46 |
+
logging.info("Applying color correction.")
|
47 |
+
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
48 |
+
cv2.cvtColor(
|
49 |
+
np.asarray(original_image),
|
50 |
+
cv2.COLOR_RGB2LAB
|
51 |
+
),
|
52 |
+
correction,
|
53 |
+
channel_axis=2
|
54 |
+
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
55 |
+
|
56 |
+
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
|
57 |
+
|
58 |
+
return image
|
59 |
+
|
60 |
+
|
61 |
+
def apply_overlay(image, paste_loc, index, overlays):
|
62 |
+
if overlays is None or index >= len(overlays):
|
63 |
+
return image
|
64 |
+
|
65 |
+
overlay = overlays[index]
|
66 |
+
|
67 |
+
if paste_loc is not None:
|
68 |
+
x, y, w, h = paste_loc
|
69 |
+
base_image = Image.new('RGBA', (overlay.width, overlay.height))
|
70 |
+
image = images.resize_image(1, image, w, h)
|
71 |
+
base_image.paste(image, (x, y))
|
72 |
+
image = base_image
|
73 |
+
|
74 |
+
image = image.convert('RGBA')
|
75 |
+
image.alpha_composite(overlay)
|
76 |
+
image = image.convert('RGB')
|
77 |
+
|
78 |
+
return image
|
79 |
+
|
80 |
+
|
81 |
+
def txt2img_image_conditioning(sd_model, x, width, height):
|
82 |
+
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
|
83 |
+
|
84 |
+
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
85 |
+
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
86 |
+
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
87 |
+
|
88 |
+
# Add the fake full 1s mask to the first dimension.
|
89 |
+
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
90 |
+
image_conditioning = image_conditioning.to(x.dtype)
|
91 |
+
|
92 |
+
return image_conditioning
|
93 |
+
|
94 |
+
elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
|
95 |
+
|
96 |
+
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
|
97 |
+
|
98 |
+
else:
|
99 |
+
# Dummy zero conditioning if we're not using inpainting or unclip models.
|
100 |
+
# Still takes up a bit of memory, but no encoder call.
|
101 |
+
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
102 |
+
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
103 |
+
|
104 |
+
|
105 |
+
class StableDiffusionProcessing:
|
106 |
+
"""
|
107 |
+
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
108 |
+
"""
|
109 |
+
cached_uc = [None, None]
|
110 |
+
cached_c = [None, None]
|
111 |
+
|
112 |
+
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
113 |
+
if sampler_index is not None:
|
114 |
+
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
115 |
+
|
116 |
+
self.outpath_samples: str = outpath_samples
|
117 |
+
self.outpath_grids: str = outpath_grids
|
118 |
+
self.prompt: str = prompt
|
119 |
+
self.prompt_for_display: str = None
|
120 |
+
self.negative_prompt: str = (negative_prompt or "")
|
121 |
+
self.styles: list = styles or []
|
122 |
+
self.seed: int = seed
|
123 |
+
self.subseed: int = subseed
|
124 |
+
self.subseed_strength: float = subseed_strength
|
125 |
+
self.seed_resize_from_h: int = seed_resize_from_h
|
126 |
+
self.seed_resize_from_w: int = seed_resize_from_w
|
127 |
+
self.sampler_name: str = sampler_name
|
128 |
+
self.batch_size: int = batch_size
|
129 |
+
self.n_iter: int = n_iter
|
130 |
+
self.steps: int = steps
|
131 |
+
self.cfg_scale: float = cfg_scale
|
132 |
+
self.width: int = width
|
133 |
+
self.height: int = height
|
134 |
+
self.restore_faces: bool = restore_faces
|
135 |
+
self.tiling: bool = tiling
|
136 |
+
self.do_not_save_samples: bool = do_not_save_samples
|
137 |
+
self.do_not_save_grid: bool = do_not_save_grid
|
138 |
+
self.extra_generation_params: dict = extra_generation_params or {}
|
139 |
+
self.overlay_images = overlay_images
|
140 |
+
self.eta = eta
|
141 |
+
self.do_not_reload_embeddings = do_not_reload_embeddings
|
142 |
+
self.paste_to = None
|
143 |
+
self.color_corrections = None
|
144 |
+
self.denoising_strength: float = denoising_strength
|
145 |
+
self.sampler_noise_scheduler_override = None
|
146 |
+
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
147 |
+
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
148 |
+
self.s_churn = s_churn or opts.s_churn
|
149 |
+
self.s_tmin = s_tmin or opts.s_tmin
|
150 |
+
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
151 |
+
self.s_noise = s_noise or opts.s_noise
|
152 |
+
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
153 |
+
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
154 |
+
self.is_using_inpainting_conditioning = False
|
155 |
+
self.disable_extra_networks = False
|
156 |
+
self.token_merging_ratio = 0
|
157 |
+
self.token_merging_ratio_hr = 0
|
158 |
+
|
159 |
+
if not seed_enable_extras:
|
160 |
+
self.subseed = -1
|
161 |
+
self.subseed_strength = 0
|
162 |
+
self.seed_resize_from_h = 0
|
163 |
+
self.seed_resize_from_w = 0
|
164 |
+
|
165 |
+
self.scripts = None
|
166 |
+
self.script_args = script_args
|
167 |
+
self.all_prompts = None
|
168 |
+
self.all_negative_prompts = None
|
169 |
+
self.all_seeds = None
|
170 |
+
self.all_subseeds = None
|
171 |
+
self.iteration = 0
|
172 |
+
self.is_hr_pass = False
|
173 |
+
self.sampler = None
|
174 |
+
|
175 |
+
self.prompts = None
|
176 |
+
self.negative_prompts = None
|
177 |
+
self.extra_network_data = None
|
178 |
+
self.seeds = None
|
179 |
+
self.subseeds = None
|
180 |
+
|
181 |
+
self.step_multiplier = 1
|
182 |
+
self.cached_uc = StableDiffusionProcessing.cached_uc
|
183 |
+
self.cached_c = StableDiffusionProcessing.cached_c
|
184 |
+
self.uc = None
|
185 |
+
self.c = None
|
186 |
+
|
187 |
+
self.user = None
|
188 |
+
|
189 |
+
@property
|
190 |
+
def sd_model(self):
|
191 |
+
return shared.sd_model
|
192 |
+
|
193 |
+
def txt2img_image_conditioning(self, x, width=None, height=None):
|
194 |
+
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
195 |
+
|
196 |
+
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
197 |
+
|
198 |
+
def depth2img_image_conditioning(self, source_image):
|
199 |
+
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
200 |
+
transformer = AddMiDaS(model_type="dpt_hybrid")
|
201 |
+
transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
|
202 |
+
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
203 |
+
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
204 |
+
|
205 |
+
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
206 |
+
conditioning = torch.nn.functional.interpolate(
|
207 |
+
self.sd_model.depth_model(midas_in),
|
208 |
+
size=conditioning_image.shape[2:],
|
209 |
+
mode="bicubic",
|
210 |
+
align_corners=False,
|
211 |
+
)
|
212 |
+
|
213 |
+
(depth_min, depth_max) = torch.aminmax(conditioning)
|
214 |
+
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
215 |
+
return conditioning
|
216 |
+
|
217 |
+
def edit_image_conditioning(self, source_image):
|
218 |
+
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
|
219 |
+
|
220 |
+
return conditioning_image
|
221 |
+
|
222 |
+
def unclip_image_conditioning(self, source_image):
|
223 |
+
c_adm = self.sd_model.embedder(source_image)
|
224 |
+
if self.sd_model.noise_augmentor is not None:
|
225 |
+
noise_level = 0 # TODO: Allow other noise levels?
|
226 |
+
c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
|
227 |
+
c_adm = torch.cat((c_adm, noise_level_emb), 1)
|
228 |
+
return c_adm
|
229 |
+
|
230 |
+
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
231 |
+
self.is_using_inpainting_conditioning = True
|
232 |
+
|
233 |
+
# Handle the different mask inputs
|
234 |
+
if image_mask is not None:
|
235 |
+
if torch.is_tensor(image_mask):
|
236 |
+
conditioning_mask = image_mask
|
237 |
+
else:
|
238 |
+
conditioning_mask = np.array(image_mask.convert("L"))
|
239 |
+
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
240 |
+
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
241 |
+
|
242 |
+
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
243 |
+
conditioning_mask = torch.round(conditioning_mask)
|
244 |
+
else:
|
245 |
+
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
|
246 |
+
|
247 |
+
# Create another latent image, this time with a masked version of the original input.
|
248 |
+
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
249 |
+
conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
|
250 |
+
conditioning_image = torch.lerp(
|
251 |
+
source_image,
|
252 |
+
source_image * (1.0 - conditioning_mask),
|
253 |
+
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
254 |
+
)
|
255 |
+
|
256 |
+
# Encode the new masked image using first stage of network.
|
257 |
+
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
258 |
+
|
259 |
+
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
260 |
+
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
261 |
+
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
262 |
+
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
263 |
+
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
264 |
+
|
265 |
+
return image_conditioning
|
266 |
+
|
267 |
+
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
268 |
+
source_image = devices.cond_cast_float(source_image)
|
269 |
+
|
270 |
+
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
271 |
+
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
272 |
+
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
273 |
+
return self.depth2img_image_conditioning(source_image)
|
274 |
+
|
275 |
+
if self.sd_model.cond_stage_key == "edit":
|
276 |
+
return self.edit_image_conditioning(source_image)
|
277 |
+
|
278 |
+
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
279 |
+
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
280 |
+
|
281 |
+
if self.sampler.conditioning_key == "crossattn-adm":
|
282 |
+
return self.unclip_image_conditioning(source_image)
|
283 |
+
|
284 |
+
# Dummy zero conditioning if we're not using inpainting or depth model.
|
285 |
+
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
286 |
+
|
287 |
+
def init(self, all_prompts, all_seeds, all_subseeds):
|
288 |
+
pass
|
289 |
+
|
290 |
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
291 |
+
raise NotImplementedError()
|
292 |
+
|
293 |
+
def close(self):
|
294 |
+
self.sampler = None
|
295 |
+
self.c = None
|
296 |
+
self.uc = None
|
297 |
+
if not opts.experimental_persistent_cond_cache:
|
298 |
+
StableDiffusionProcessing.cached_c = [None, None]
|
299 |
+
StableDiffusionProcessing.cached_uc = [None, None]
|
300 |
+
|
301 |
+
def get_token_merging_ratio(self, for_hr=False):
|
302 |
+
if for_hr:
|
303 |
+
return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
|
304 |
+
|
305 |
+
return self.token_merging_ratio or opts.token_merging_ratio
|
306 |
+
|
307 |
+
def setup_prompts(self):
|
308 |
+
if type(self.prompt) == list:
|
309 |
+
self.all_prompts = self.prompt
|
310 |
+
else:
|
311 |
+
self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
|
312 |
+
|
313 |
+
if type(self.negative_prompt) == list:
|
314 |
+
self.all_negative_prompts = self.negative_prompt
|
315 |
+
else:
|
316 |
+
self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
|
317 |
+
|
318 |
+
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
319 |
+
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
320 |
+
|
321 |
+
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
|
322 |
+
"""
|
323 |
+
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
324 |
+
using a cache to store the result if the same arguments have been used before.
|
325 |
+
|
326 |
+
cache is an array containing two elements. The first element is a tuple
|
327 |
+
representing the previously used arguments, or None if no arguments
|
328 |
+
have been used before. The second element is where the previously
|
329 |
+
computed result is stored.
|
330 |
+
|
331 |
+
caches is a list with items described above.
|
332 |
+
"""
|
333 |
+
|
334 |
+
cached_params = (
|
335 |
+
required_prompts,
|
336 |
+
steps,
|
337 |
+
opts.CLIP_stop_at_last_layers,
|
338 |
+
shared.sd_model.sd_checkpoint_info,
|
339 |
+
extra_network_data,
|
340 |
+
opts.sdxl_crop_left,
|
341 |
+
opts.sdxl_crop_top,
|
342 |
+
self.width,
|
343 |
+
self.height,
|
344 |
+
)
|
345 |
+
|
346 |
+
for cache in caches:
|
347 |
+
if cache[0] is not None and cached_params == cache[0]:
|
348 |
+
return cache[1]
|
349 |
+
|
350 |
+
cache = caches[0]
|
351 |
+
|
352 |
+
with devices.autocast():
|
353 |
+
cache[1] = function(shared.sd_model, required_prompts, steps)
|
354 |
+
|
355 |
+
cache[0] = cached_params
|
356 |
+
return cache[1]
|
357 |
+
|
358 |
+
def setup_conds(self):
|
359 |
+
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
|
360 |
+
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
361 |
+
|
362 |
+
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
363 |
+
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
364 |
+
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
365 |
+
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
366 |
+
|
367 |
+
def parse_extra_network_prompts(self):
|
368 |
+
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
369 |
+
|
370 |
+
|
371 |
+
class Processed:
|
372 |
+
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
373 |
+
self.images = images_list
|
374 |
+
self.prompt = p.prompt
|
375 |
+
self.negative_prompt = p.negative_prompt
|
376 |
+
self.seed = seed
|
377 |
+
self.subseed = subseed
|
378 |
+
self.subseed_strength = p.subseed_strength
|
379 |
+
self.info = info
|
380 |
+
self.comments = comments
|
381 |
+
self.width = p.width
|
382 |
+
self.height = p.height
|
383 |
+
self.sampler_name = p.sampler_name
|
384 |
+
self.cfg_scale = p.cfg_scale
|
385 |
+
self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
386 |
+
self.steps = p.steps
|
387 |
+
self.batch_size = p.batch_size
|
388 |
+
self.restore_faces = p.restore_faces
|
389 |
+
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
390 |
+
self.sd_model_hash = shared.sd_model.sd_model_hash
|
391 |
+
self.seed_resize_from_w = p.seed_resize_from_w
|
392 |
+
self.seed_resize_from_h = p.seed_resize_from_h
|
393 |
+
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
394 |
+
self.extra_generation_params = p.extra_generation_params
|
395 |
+
self.index_of_first_image = index_of_first_image
|
396 |
+
self.styles = p.styles
|
397 |
+
self.job_timestamp = state.job_timestamp
|
398 |
+
self.clip_skip = opts.CLIP_stop_at_last_layers
|
399 |
+
self.token_merging_ratio = p.token_merging_ratio
|
400 |
+
self.token_merging_ratio_hr = p.token_merging_ratio_hr
|
401 |
+
|
402 |
+
self.eta = p.eta
|
403 |
+
self.ddim_discretize = p.ddim_discretize
|
404 |
+
self.s_churn = p.s_churn
|
405 |
+
self.s_tmin = p.s_tmin
|
406 |
+
self.s_tmax = p.s_tmax
|
407 |
+
self.s_noise = p.s_noise
|
408 |
+
self.s_min_uncond = p.s_min_uncond
|
409 |
+
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
410 |
+
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
411 |
+
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
412 |
+
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
413 |
+
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
414 |
+
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
415 |
+
|
416 |
+
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
|
417 |
+
self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
|
418 |
+
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
419 |
+
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
420 |
+
self.infotexts = infotexts or [info]
|
421 |
+
|
422 |
+
def js(self):
|
423 |
+
obj = {
|
424 |
+
"prompt": self.all_prompts[0],
|
425 |
+
"all_prompts": self.all_prompts,
|
426 |
+
"negative_prompt": self.all_negative_prompts[0],
|
427 |
+
"all_negative_prompts": self.all_negative_prompts,
|
428 |
+
"seed": self.seed,
|
429 |
+
"all_seeds": self.all_seeds,
|
430 |
+
"subseed": self.subseed,
|
431 |
+
"all_subseeds": self.all_subseeds,
|
432 |
+
"subseed_strength": self.subseed_strength,
|
433 |
+
"width": self.width,
|
434 |
+
"height": self.height,
|
435 |
+
"sampler_name": self.sampler_name,
|
436 |
+
"cfg_scale": self.cfg_scale,
|
437 |
+
"steps": self.steps,
|
438 |
+
"batch_size": self.batch_size,
|
439 |
+
"restore_faces": self.restore_faces,
|
440 |
+
"face_restoration_model": self.face_restoration_model,
|
441 |
+
"sd_model_hash": self.sd_model_hash,
|
442 |
+
"seed_resize_from_w": self.seed_resize_from_w,
|
443 |
+
"seed_resize_from_h": self.seed_resize_from_h,
|
444 |
+
"denoising_strength": self.denoising_strength,
|
445 |
+
"extra_generation_params": self.extra_generation_params,
|
446 |
+
"index_of_first_image": self.index_of_first_image,
|
447 |
+
"infotexts": self.infotexts,
|
448 |
+
"styles": self.styles,
|
449 |
+
"job_timestamp": self.job_timestamp,
|
450 |
+
"clip_skip": self.clip_skip,
|
451 |
+
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
|
452 |
+
}
|
453 |
+
|
454 |
+
return json.dumps(obj)
|
455 |
+
|
456 |
+
def infotext(self, p: StableDiffusionProcessing, index):
|
457 |
+
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
458 |
+
|
459 |
+
def get_token_merging_ratio(self, for_hr=False):
|
460 |
+
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
|
461 |
+
|
462 |
+
|
463 |
+
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
464 |
+
def slerp(val, low, high):
|
465 |
+
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
466 |
+
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
467 |
+
dot = (low_norm*high_norm).sum(1)
|
468 |
+
|
469 |
+
if dot.mean() > 0.9995:
|
470 |
+
return low * val + high * (1 - val)
|
471 |
+
|
472 |
+
omega = torch.acos(dot)
|
473 |
+
so = torch.sin(omega)
|
474 |
+
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
475 |
+
return res
|
476 |
+
|
477 |
+
|
478 |
+
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
479 |
+
eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
|
480 |
+
xs = []
|
481 |
+
|
482 |
+
# if we have multiple seeds, this means we are working with batch size>1; this then
|
483 |
+
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
484 |
+
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
485 |
+
# produce the same images as with two batches [100], [101].
|
486 |
+
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
|
487 |
+
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
488 |
+
else:
|
489 |
+
sampler_noises = None
|
490 |
+
|
491 |
+
for i, seed in enumerate(seeds):
|
492 |
+
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
493 |
+
|
494 |
+
subnoise = None
|
495 |
+
if subseeds is not None:
|
496 |
+
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
497 |
+
|
498 |
+
subnoise = devices.randn(subseed, noise_shape)
|
499 |
+
|
500 |
+
# randn results depend on device; gpu and cpu get different results for same seed;
|
501 |
+
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
502 |
+
# but the original script had it like this, so I do not dare change it for now because
|
503 |
+
# it will break everyone's seeds.
|
504 |
+
noise = devices.randn(seed, noise_shape)
|
505 |
+
|
506 |
+
if subnoise is not None:
|
507 |
+
noise = slerp(subseed_strength, noise, subnoise)
|
508 |
+
|
509 |
+
if noise_shape != shape:
|
510 |
+
x = devices.randn(seed, shape)
|
511 |
+
dx = (shape[2] - noise_shape[2]) // 2
|
512 |
+
dy = (shape[1] - noise_shape[1]) // 2
|
513 |
+
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
514 |
+
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
515 |
+
tx = 0 if dx < 0 else dx
|
516 |
+
ty = 0 if dy < 0 else dy
|
517 |
+
dx = max(-dx, 0)
|
518 |
+
dy = max(-dy, 0)
|
519 |
+
|
520 |
+
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
|
521 |
+
noise = x
|
522 |
+
|
523 |
+
if sampler_noises is not None:
|
524 |
+
cnt = p.sampler.number_of_needed_noises(p)
|
525 |
+
|
526 |
+
if eta_noise_seed_delta > 0:
|
527 |
+
torch.manual_seed(seed + eta_noise_seed_delta)
|
528 |
+
|
529 |
+
for j in range(cnt):
|
530 |
+
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
531 |
+
|
532 |
+
xs.append(noise)
|
533 |
+
|
534 |
+
if sampler_noises is not None:
|
535 |
+
p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
|
536 |
+
|
537 |
+
x = torch.stack(xs).to(shared.device)
|
538 |
+
return x
|
539 |
+
|
540 |
+
|
541 |
+
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
542 |
+
samples = []
|
543 |
+
|
544 |
+
for i in range(batch.shape[0]):
|
545 |
+
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
546 |
+
|
547 |
+
if check_for_nans:
|
548 |
+
try:
|
549 |
+
devices.test_for_nans(sample, "vae")
|
550 |
+
except devices.NansException as e:
|
551 |
+
if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
|
552 |
+
raise e
|
553 |
+
|
554 |
+
errors.print_error_explanation(
|
555 |
+
"A tensor with all NaNs was produced in VAE.\n"
|
556 |
+
"Web UI will now convert VAE into 32-bit float and retry.\n"
|
557 |
+
"To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
|
558 |
+
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
|
559 |
+
)
|
560 |
+
|
561 |
+
devices.dtype_vae = torch.float32
|
562 |
+
model.first_stage_model.to(devices.dtype_vae)
|
563 |
+
batch = batch.to(devices.dtype_vae)
|
564 |
+
|
565 |
+
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
566 |
+
|
567 |
+
if target_device is not None:
|
568 |
+
sample = sample.to(target_device)
|
569 |
+
|
570 |
+
samples.append(sample)
|
571 |
+
|
572 |
+
return samples
|
573 |
+
|
574 |
+
|
575 |
+
def decode_first_stage(model, x):
|
576 |
+
x = model.decode_first_stage(x.to(devices.dtype_vae))
|
577 |
+
|
578 |
+
return x
|
579 |
+
|
580 |
+
|
581 |
+
def get_fixed_seed(seed):
|
582 |
+
if seed is None or seed == '' or seed == -1:
|
583 |
+
return int(random.randrange(4294967294))
|
584 |
+
|
585 |
+
return seed
|
586 |
+
|
587 |
+
|
588 |
+
def fix_seed(p):
|
589 |
+
p.seed = get_fixed_seed(p.seed)
|
590 |
+
p.subseed = get_fixed_seed(p.subseed)
|
591 |
+
|
592 |
+
|
593 |
+
def program_version():
|
594 |
+
import launch
|
595 |
+
|
596 |
+
res = launch.git_tag()
|
597 |
+
if res == "<none>":
|
598 |
+
res = None
|
599 |
+
|
600 |
+
return res
|
601 |
+
|
602 |
+
|
603 |
+
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
|
604 |
+
index = position_in_batch + iteration * p.batch_size
|
605 |
+
|
606 |
+
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
607 |
+
enable_hr = getattr(p, 'enable_hr', False)
|
608 |
+
token_merging_ratio = p.get_token_merging_ratio()
|
609 |
+
token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
|
610 |
+
|
611 |
+
uses_ensd = opts.eta_noise_seed_delta != 0
|
612 |
+
if uses_ensd:
|
613 |
+
uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
|
614 |
+
|
615 |
+
generation_params = {
|
616 |
+
"Steps": p.steps,
|
617 |
+
"Sampler": p.sampler_name,
|
618 |
+
"CFG scale": p.cfg_scale,
|
619 |
+
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
620 |
+
"Seed": all_seeds[index],
|
621 |
+
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
622 |
+
"Size": f"{p.width}x{p.height}",
|
623 |
+
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
624 |
+
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
|
625 |
+
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
626 |
+
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
627 |
+
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
628 |
+
"Denoising strength": getattr(p, 'denoising_strength', None),
|
629 |
+
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
630 |
+
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
631 |
+
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
|
632 |
+
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
|
633 |
+
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
|
634 |
+
"Init image hash": getattr(p, 'init_img_hash', None),
|
635 |
+
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
636 |
+
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
637 |
+
**p.extra_generation_params,
|
638 |
+
"Version": program_version() if opts.add_version_to_infotext else None,
|
639 |
+
"User": p.user if opts.add_user_name_to_info else None,
|
640 |
+
}
|
641 |
+
|
642 |
+
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
643 |
+
|
644 |
+
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
|
645 |
+
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
646 |
+
|
647 |
+
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
648 |
+
|
649 |
+
|
650 |
+
def process_images(p: StableDiffusionProcessing) -> Processed:
|
651 |
+
if p.scripts is not None:
|
652 |
+
p.scripts.before_process(p)
|
653 |
+
|
654 |
+
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
655 |
+
|
656 |
+
try:
|
657 |
+
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
658 |
+
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
659 |
+
p.override_settings.pop('sd_model_checkpoint', None)
|
660 |
+
sd_models.reload_model_weights()
|
661 |
+
|
662 |
+
for k, v in p.override_settings.items():
|
663 |
+
setattr(opts, k, v)
|
664 |
+
|
665 |
+
if k == 'sd_model_checkpoint':
|
666 |
+
sd_models.reload_model_weights()
|
667 |
+
|
668 |
+
if k == 'sd_vae':
|
669 |
+
sd_vae.reload_vae_weights()
|
670 |
+
|
671 |
+
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
|
672 |
+
|
673 |
+
res = process_images_inner(p)
|
674 |
+
|
675 |
+
finally:
|
676 |
+
sd_models.apply_token_merging(p.sd_model, 0)
|
677 |
+
|
678 |
+
# restore opts to original state
|
679 |
+
if p.override_settings_restore_afterwards:
|
680 |
+
for k, v in stored_opts.items():
|
681 |
+
setattr(opts, k, v)
|
682 |
+
|
683 |
+
if k == 'sd_vae':
|
684 |
+
sd_vae.reload_vae_weights()
|
685 |
+
|
686 |
+
return res
|
687 |
+
|
688 |
+
|
689 |
+
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
690 |
+
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
691 |
+
|
692 |
+
if type(p.prompt) == list:
|
693 |
+
assert(len(p.prompt) > 0)
|
694 |
+
else:
|
695 |
+
assert p.prompt is not None
|
696 |
+
|
697 |
+
devices.torch_gc()
|
698 |
+
|
699 |
+
seed = get_fixed_seed(p.seed)
|
700 |
+
subseed = get_fixed_seed(p.subseed)
|
701 |
+
|
702 |
+
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
703 |
+
modules.sd_hijack.model_hijack.clear_comments()
|
704 |
+
|
705 |
+
comments = {}
|
706 |
+
|
707 |
+
p.setup_prompts()
|
708 |
+
|
709 |
+
if type(seed) == list:
|
710 |
+
p.all_seeds = seed
|
711 |
+
else:
|
712 |
+
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
|
713 |
+
|
714 |
+
if type(subseed) == list:
|
715 |
+
p.all_subseeds = subseed
|
716 |
+
else:
|
717 |
+
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
718 |
+
|
719 |
+
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
|
720 |
+
all_prompts = p.all_prompts[:]
|
721 |
+
all_negative_prompts = p.all_negative_prompts[:]
|
722 |
+
all_seeds = p.all_seeds[:]
|
723 |
+
all_subseeds = p.all_subseeds[:]
|
724 |
+
|
725 |
+
# apply changes to generation data
|
726 |
+
all_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.prompts
|
727 |
+
all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts
|
728 |
+
all_seeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.seeds
|
729 |
+
all_subseeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.subseeds
|
730 |
+
|
731 |
+
# update p.all_negative_prompts in case extensions changed the size of the batch
|
732 |
+
# create_infotext below uses it
|
733 |
+
old_negative_prompts = p.all_negative_prompts
|
734 |
+
p.all_negative_prompts = all_negative_prompts
|
735 |
+
|
736 |
+
try:
|
737 |
+
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
|
738 |
+
finally:
|
739 |
+
# restore p.all_negative_prompts in case extensions changed the size of the batch
|
740 |
+
p.all_negative_prompts = old_negative_prompts
|
741 |
+
|
742 |
+
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
743 |
+
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
744 |
+
|
745 |
+
if p.scripts is not None:
|
746 |
+
p.scripts.process(p)
|
747 |
+
|
748 |
+
infotexts = []
|
749 |
+
output_images = []
|
750 |
+
|
751 |
+
with torch.no_grad(), p.sd_model.ema_scope():
|
752 |
+
with devices.autocast():
|
753 |
+
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
754 |
+
|
755 |
+
# for OSX, loading the model during sampling changes the generated picture, so it is loaded here
|
756 |
+
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
757 |
+
sd_vae_approx.model()
|
758 |
+
|
759 |
+
sd_unet.apply_unet()
|
760 |
+
|
761 |
+
if state.job_count == -1:
|
762 |
+
state.job_count = p.n_iter
|
763 |
+
|
764 |
+
for n in range(p.n_iter):
|
765 |
+
p.iteration = n
|
766 |
+
|
767 |
+
if state.skipped:
|
768 |
+
state.skipped = False
|
769 |
+
|
770 |
+
if state.interrupted:
|
771 |
+
break
|
772 |
+
|
773 |
+
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
774 |
+
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
775 |
+
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
776 |
+
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
777 |
+
|
778 |
+
if p.scripts is not None:
|
779 |
+
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
780 |
+
|
781 |
+
if len(p.prompts) == 0:
|
782 |
+
break
|
783 |
+
|
784 |
+
p.parse_extra_network_prompts()
|
785 |
+
|
786 |
+
if not p.disable_extra_networks:
|
787 |
+
with devices.autocast():
|
788 |
+
extra_networks.activate(p, p.extra_network_data)
|
789 |
+
|
790 |
+
if p.scripts is not None:
|
791 |
+
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
792 |
+
|
793 |
+
# params.txt should be saved after scripts.process_batch, since the
|
794 |
+
# infotext could be modified by that callback
|
795 |
+
# Example: a wildcard processed by process_batch sets an extra model
|
796 |
+
# strength, which is saved as "Model Strength: 1.0" in the infotext
|
797 |
+
if n == 0:
|
798 |
+
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
799 |
+
processed = Processed(p, [], p.seed, "")
|
800 |
+
file.write(processed.infotext(p, 0))
|
801 |
+
|
802 |
+
p.setup_conds()
|
803 |
+
|
804 |
+
for comment in model_hijack.comments:
|
805 |
+
comments[comment] = 1
|
806 |
+
|
807 |
+
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
808 |
+
|
809 |
+
if p.n_iter > 1:
|
810 |
+
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
811 |
+
|
812 |
+
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
813 |
+
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
814 |
+
|
815 |
+
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
816 |
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
817 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
818 |
+
|
819 |
+
del samples_ddim
|
820 |
+
|
821 |
+
if lowvram.is_enabled(shared.sd_model):
|
822 |
+
lowvram.send_everything_to_cpu()
|
823 |
+
|
824 |
+
devices.torch_gc()
|
825 |
+
|
826 |
+
if p.scripts is not None:
|
827 |
+
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
828 |
+
|
829 |
+
postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
|
830 |
+
p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
|
831 |
+
x_samples_ddim = postprocess_batch_list_args.images
|
832 |
+
|
833 |
+
for i, x_sample in enumerate(x_samples_ddim):
|
834 |
+
p.batch_index = i
|
835 |
+
|
836 |
+
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
837 |
+
x_sample = x_sample.astype(np.uint8)
|
838 |
+
|
839 |
+
if p.restore_faces:
|
840 |
+
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
841 |
+
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
842 |
+
|
843 |
+
devices.torch_gc()
|
844 |
+
|
845 |
+
x_sample = modules.face_restoration.restore_faces(x_sample)
|
846 |
+
devices.torch_gc()
|
847 |
+
|
848 |
+
image = Image.fromarray(x_sample)
|
849 |
+
|
850 |
+
if p.scripts is not None:
|
851 |
+
pp = scripts.PostprocessImageArgs(image)
|
852 |
+
p.scripts.postprocess_image(p, pp)
|
853 |
+
image = pp.image
|
854 |
+
|
855 |
+
if p.color_corrections is not None and i < len(p.color_corrections):
|
856 |
+
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
857 |
+
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
858 |
+
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
|
859 |
+
image = apply_color_correction(p.color_corrections[i], image)
|
860 |
+
|
861 |
+
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
862 |
+
|
863 |
+
if opts.samples_save and not p.do_not_save_samples:
|
864 |
+
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
865 |
+
|
866 |
+
text = infotext(n, i)
|
867 |
+
infotexts.append(text)
|
868 |
+
if opts.enable_pnginfo:
|
869 |
+
image.info["parameters"] = text
|
870 |
+
output_images.append(image)
|
871 |
+
|
872 |
+
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
|
873 |
+
image_mask = p.mask_for_overlay.convert('RGB')
|
874 |
+
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
875 |
+
|
876 |
+
if opts.save_mask:
|
877 |
+
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
878 |
+
|
879 |
+
if opts.save_mask_composite:
|
880 |
+
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
|
881 |
+
|
882 |
+
if opts.return_mask:
|
883 |
+
output_images.append(image_mask)
|
884 |
+
|
885 |
+
if opts.return_mask_composite:
|
886 |
+
output_images.append(image_mask_composite)
|
887 |
+
|
888 |
+
del x_samples_ddim
|
889 |
+
|
890 |
+
devices.torch_gc()
|
891 |
+
|
892 |
+
state.nextjob()
|
893 |
+
|
894 |
+
p.color_corrections = None
|
895 |
+
|
896 |
+
index_of_first_image = 0
|
897 |
+
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
|
898 |
+
if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
|
899 |
+
grid = images.image_grid(output_images, p.batch_size)
|
900 |
+
|
901 |
+
if opts.return_grid:
|
902 |
+
text = infotext(use_main_prompt=True)
|
903 |
+
infotexts.insert(0, text)
|
904 |
+
if opts.enable_pnginfo:
|
905 |
+
grid.info["parameters"] = text
|
906 |
+
output_images.insert(0, grid)
|
907 |
+
index_of_first_image = 1
|
908 |
+
|
909 |
+
if opts.grid_save:
|
910 |
+
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
911 |
+
|
912 |
+
if not p.disable_extra_networks and p.extra_network_data:
|
913 |
+
extra_networks.deactivate(p, p.extra_network_data)
|
914 |
+
|
915 |
+
devices.torch_gc()
|
916 |
+
|
917 |
+
res = Processed(
|
918 |
+
p,
|
919 |
+
images_list=output_images,
|
920 |
+
seed=p.all_seeds[0],
|
921 |
+
info=infotext(),
|
922 |
+
comments="".join(f"{comment}\n" for comment in comments),
|
923 |
+
subseed=p.all_subseeds[0],
|
924 |
+
index_of_first_image=index_of_first_image,
|
925 |
+
infotexts=infotexts,
|
926 |
+
)
|
927 |
+
|
928 |
+
if p.scripts is not None:
|
929 |
+
p.scripts.postprocess(p, res)
|
930 |
+
|
931 |
+
return res
|
932 |
+
|
933 |
+
|
934 |
+
def old_hires_fix_first_pass_dimensions(width, height):
|
935 |
+
"""old algorithm for auto-calculating first pass size"""
|
936 |
+
|
937 |
+
desired_pixel_count = 512 * 512
|
938 |
+
actual_pixel_count = width * height
|
939 |
+
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
940 |
+
width = math.ceil(scale * width / 64) * 64
|
941 |
+
height = math.ceil(scale * height / 64) * 64
|
942 |
+
|
943 |
+
return width, height
|
944 |
+
|
945 |
+
|
946 |
+
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
947 |
+
sampler = None
|
948 |
+
cached_hr_uc = [None, None]
|
949 |
+
cached_hr_c = [None, None]
|
950 |
+
|
951 |
+
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
952 |
+
super().__init__(**kwargs)
|
953 |
+
self.enable_hr = enable_hr
|
954 |
+
self.denoising_strength = denoising_strength
|
955 |
+
self.hr_scale = hr_scale
|
956 |
+
self.hr_upscaler = hr_upscaler
|
957 |
+
self.hr_second_pass_steps = hr_second_pass_steps
|
958 |
+
self.hr_resize_x = hr_resize_x
|
959 |
+
self.hr_resize_y = hr_resize_y
|
960 |
+
self.hr_upscale_to_x = hr_resize_x
|
961 |
+
self.hr_upscale_to_y = hr_resize_y
|
962 |
+
self.hr_sampler_name = hr_sampler_name
|
963 |
+
self.hr_prompt = hr_prompt
|
964 |
+
self.hr_negative_prompt = hr_negative_prompt
|
965 |
+
self.all_hr_prompts = None
|
966 |
+
self.all_hr_negative_prompts = None
|
967 |
+
|
968 |
+
if firstphase_width != 0 or firstphase_height != 0:
|
969 |
+
self.hr_upscale_to_x = self.width
|
970 |
+
self.hr_upscale_to_y = self.height
|
971 |
+
self.width = firstphase_width
|
972 |
+
self.height = firstphase_height
|
973 |
+
|
974 |
+
self.truncate_x = 0
|
975 |
+
self.truncate_y = 0
|
976 |
+
self.applied_old_hires_behavior_to = None
|
977 |
+
|
978 |
+
self.hr_prompts = None
|
979 |
+
self.hr_negative_prompts = None
|
980 |
+
self.hr_extra_network_data = None
|
981 |
+
|
982 |
+
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
|
983 |
+
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
|
984 |
+
self.hr_c = None
|
985 |
+
self.hr_uc = None
|
986 |
+
|
987 |
+
def init(self, all_prompts, all_seeds, all_subseeds):
|
988 |
+
if self.enable_hr:
|
989 |
+
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
990 |
+
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
991 |
+
|
992 |
+
if tuple(self.hr_prompt) != tuple(self.prompt):
|
993 |
+
self.extra_generation_params["Hires prompt"] = self.hr_prompt
|
994 |
+
|
995 |
+
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
996 |
+
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
997 |
+
|
998 |
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
999 |
+
self.hr_resize_x = self.width
|
1000 |
+
self.hr_resize_y = self.height
|
1001 |
+
self.hr_upscale_to_x = self.width
|
1002 |
+
self.hr_upscale_to_y = self.height
|
1003 |
+
|
1004 |
+
self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
|
1005 |
+
self.applied_old_hires_behavior_to = (self.width, self.height)
|
1006 |
+
|
1007 |
+
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
|
1008 |
+
self.extra_generation_params["Hires upscale"] = self.hr_scale
|
1009 |
+
self.hr_upscale_to_x = int(self.width * self.hr_scale)
|
1010 |
+
self.hr_upscale_to_y = int(self.height * self.hr_scale)
|
1011 |
+
else:
|
1012 |
+
self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
|
1013 |
+
|
1014 |
+
if self.hr_resize_y == 0:
|
1015 |
+
self.hr_upscale_to_x = self.hr_resize_x
|
1016 |
+
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
1017 |
+
elif self.hr_resize_x == 0:
|
1018 |
+
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
1019 |
+
self.hr_upscale_to_y = self.hr_resize_y
|
1020 |
+
else:
|
1021 |
+
target_w = self.hr_resize_x
|
1022 |
+
target_h = self.hr_resize_y
|
1023 |
+
src_ratio = self.width / self.height
|
1024 |
+
dst_ratio = self.hr_resize_x / self.hr_resize_y
|
1025 |
+
|
1026 |
+
if src_ratio < dst_ratio:
|
1027 |
+
self.hr_upscale_to_x = self.hr_resize_x
|
1028 |
+
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
1029 |
+
else:
|
1030 |
+
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
1031 |
+
self.hr_upscale_to_y = self.hr_resize_y
|
1032 |
+
|
1033 |
+
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
1034 |
+
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
1035 |
+
|
1036 |
+
# special case: the user has chosen to do nothing
|
1037 |
+
if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
|
1038 |
+
self.enable_hr = False
|
1039 |
+
self.denoising_strength = None
|
1040 |
+
self.extra_generation_params.pop("Hires upscale", None)
|
1041 |
+
self.extra_generation_params.pop("Hires resize", None)
|
1042 |
+
return
|
1043 |
+
|
1044 |
+
if not state.processing_has_refined_job_count:
|
1045 |
+
if state.job_count == -1:
|
1046 |
+
state.job_count = self.n_iter
|
1047 |
+
|
1048 |
+
shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
|
1049 |
+
state.job_count = state.job_count * 2
|
1050 |
+
state.processing_has_refined_job_count = True
|
1051 |
+
|
1052 |
+
if self.hr_second_pass_steps:
|
1053 |
+
self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
|
1054 |
+
|
1055 |
+
if self.hr_upscaler is not None:
|
1056 |
+
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
1057 |
+
|
1058 |
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
1059 |
+
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
1060 |
+
|
1061 |
+
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
1062 |
+
if self.enable_hr and latent_scale_mode is None:
|
1063 |
+
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
1064 |
+
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
1065 |
+
|
1066 |
+
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
1067 |
+
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
1068 |
+
|
1069 |
+
if not self.enable_hr:
|
1070 |
+
return samples
|
1071 |
+
|
1072 |
+
self.is_hr_pass = True
|
1073 |
+
|
1074 |
+
target_width = self.hr_upscale_to_x
|
1075 |
+
target_height = self.hr_upscale_to_y
|
1076 |
+
|
1077 |
+
def save_intermediate(image, index):
|
1078 |
+
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
1079 |
+
|
1080 |
+
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
1081 |
+
return
|
1082 |
+
|
1083 |
+
if not isinstance(image, Image.Image):
|
1084 |
+
image = sd_samplers.sample_to_image(image, index, approximation=0)
|
1085 |
+
|
1086 |
+
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
|
1087 |
+
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
|
1088 |
+
|
1089 |
+
if latent_scale_mode is not None:
|
1090 |
+
for i in range(samples.shape[0]):
|
1091 |
+
save_intermediate(samples, i)
|
1092 |
+
|
1093 |
+
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
|
1094 |
+
|
1095 |
+
# Avoid making the inpainting conditioning unless necessary as
|
1096 |
+
# this does need some extra compute to decode / encode the image again.
|
1097 |
+
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
1098 |
+
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
1099 |
+
else:
|
1100 |
+
image_conditioning = self.txt2img_image_conditioning(samples)
|
1101 |
+
else:
|
1102 |
+
decoded_samples = decode_first_stage(self.sd_model, samples)
|
1103 |
+
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
1104 |
+
|
1105 |
+
batch_images = []
|
1106 |
+
for i, x_sample in enumerate(lowres_samples):
|
1107 |
+
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
1108 |
+
x_sample = x_sample.astype(np.uint8)
|
1109 |
+
image = Image.fromarray(x_sample)
|
1110 |
+
|
1111 |
+
save_intermediate(image, i)
|
1112 |
+
|
1113 |
+
image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
|
1114 |
+
image = np.array(image).astype(np.float32) / 255.0
|
1115 |
+
image = np.moveaxis(image, 2, 0)
|
1116 |
+
batch_images.append(image)
|
1117 |
+
|
1118 |
+
decoded_samples = torch.from_numpy(np.array(batch_images))
|
1119 |
+
decoded_samples = decoded_samples.to(shared.device)
|
1120 |
+
decoded_samples = 2. * decoded_samples - 1.
|
1121 |
+
|
1122 |
+
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
1123 |
+
|
1124 |
+
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
|
1125 |
+
|
1126 |
+
shared.state.nextjob()
|
1127 |
+
|
1128 |
+
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
1129 |
+
|
1130 |
+
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
1131 |
+
img2img_sampler_name = 'DDIM'
|
1132 |
+
|
1133 |
+
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
1134 |
+
|
1135 |
+
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
1136 |
+
|
1137 |
+
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
1138 |
+
|
1139 |
+
# GC now before running the next img2img to prevent running out of memory
|
1140 |
+
x = None
|
1141 |
+
devices.torch_gc()
|
1142 |
+
|
1143 |
+
if not self.disable_extra_networks:
|
1144 |
+
with devices.autocast():
|
1145 |
+
extra_networks.activate(self, self.hr_extra_network_data)
|
1146 |
+
|
1147 |
+
with devices.autocast():
|
1148 |
+
self.calculate_hr_conds()
|
1149 |
+
|
1150 |
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
1151 |
+
|
1152 |
+
if self.scripts is not None:
|
1153 |
+
self.scripts.before_hr(self)
|
1154 |
+
|
1155 |
+
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
1156 |
+
|
1157 |
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
1158 |
+
|
1159 |
+
self.is_hr_pass = False
|
1160 |
+
|
1161 |
+
return samples
|
1162 |
+
|
1163 |
+
def close(self):
|
1164 |
+
super().close()
|
1165 |
+
self.hr_c = None
|
1166 |
+
self.hr_uc = None
|
1167 |
+
if not opts.experimental_persistent_cond_cache:
|
1168 |
+
StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
|
1169 |
+
StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
|
1170 |
+
|
1171 |
+
def setup_prompts(self):
|
1172 |
+
super().setup_prompts()
|
1173 |
+
|
1174 |
+
if not self.enable_hr:
|
1175 |
+
return
|
1176 |
+
|
1177 |
+
if self.hr_prompt == '':
|
1178 |
+
self.hr_prompt = self.prompt
|
1179 |
+
|
1180 |
+
if self.hr_negative_prompt == '':
|
1181 |
+
self.hr_negative_prompt = self.negative_prompt
|
1182 |
+
|
1183 |
+
if type(self.hr_prompt) == list:
|
1184 |
+
self.all_hr_prompts = self.hr_prompt
|
1185 |
+
else:
|
1186 |
+
self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
|
1187 |
+
|
1188 |
+
if type(self.hr_negative_prompt) == list:
|
1189 |
+
self.all_hr_negative_prompts = self.hr_negative_prompt
|
1190 |
+
else:
|
1191 |
+
self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
|
1192 |
+
|
1193 |
+
self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
|
1194 |
+
self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
|
1195 |
+
|
1196 |
+
def calculate_hr_conds(self):
|
1197 |
+
if self.hr_c is not None:
|
1198 |
+
return
|
1199 |
+
|
1200 |
+
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
1201 |
+
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
1202 |
+
|
1203 |
+
def setup_conds(self):
|
1204 |
+
super().setup_conds()
|
1205 |
+
|
1206 |
+
self.hr_uc = None
|
1207 |
+
self.hr_c = None
|
1208 |
+
|
1209 |
+
if self.enable_hr:
|
1210 |
+
if shared.opts.hires_fix_use_firstpass_conds:
|
1211 |
+
self.calculate_hr_conds()
|
1212 |
+
|
1213 |
+
elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
|
1214 |
+
with devices.autocast():
|
1215 |
+
extra_networks.activate(self, self.hr_extra_network_data)
|
1216 |
+
|
1217 |
+
self.calculate_hr_conds()
|
1218 |
+
|
1219 |
+
with devices.autocast():
|
1220 |
+
extra_networks.activate(self, self.extra_network_data)
|
1221 |
+
|
1222 |
+
def parse_extra_network_prompts(self):
|
1223 |
+
res = super().parse_extra_network_prompts()
|
1224 |
+
|
1225 |
+
if self.enable_hr:
|
1226 |
+
self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
|
1227 |
+
self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
|
1228 |
+
|
1229 |
+
self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
|
1230 |
+
|
1231 |
+
return res
|
1232 |
+
|
1233 |
+
|
1234 |
+
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
1235 |
+
sampler = None
|
1236 |
+
|
1237 |
+
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
1238 |
+
super().__init__(**kwargs)
|
1239 |
+
|
1240 |
+
self.init_images = init_images
|
1241 |
+
self.resize_mode: int = resize_mode
|
1242 |
+
self.denoising_strength: float = denoising_strength
|
1243 |
+
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
1244 |
+
self.init_latent = None
|
1245 |
+
self.image_mask = mask
|
1246 |
+
self.latent_mask = None
|
1247 |
+
self.mask_for_overlay = None
|
1248 |
+
if mask_blur is not None:
|
1249 |
+
mask_blur_x = mask_blur
|
1250 |
+
mask_blur_y = mask_blur
|
1251 |
+
self.mask_blur_x = mask_blur_x
|
1252 |
+
self.mask_blur_y = mask_blur_y
|
1253 |
+
self.inpainting_fill = inpainting_fill
|
1254 |
+
self.inpaint_full_res = inpaint_full_res
|
1255 |
+
self.inpaint_full_res_padding = inpaint_full_res_padding
|
1256 |
+
self.inpainting_mask_invert = inpainting_mask_invert
|
1257 |
+
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
|
1258 |
+
self.mask = None
|
1259 |
+
self.nmask = None
|
1260 |
+
self.image_conditioning = None
|
1261 |
+
|
1262 |
+
def init(self, all_prompts, all_seeds, all_subseeds):
|
1263 |
+
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
1264 |
+
crop_region = None
|
1265 |
+
|
1266 |
+
image_mask = self.image_mask
|
1267 |
+
|
1268 |
+
if image_mask is not None:
|
1269 |
+
image_mask = image_mask.convert('L')
|
1270 |
+
|
1271 |
+
if self.inpainting_mask_invert:
|
1272 |
+
image_mask = ImageOps.invert(image_mask)
|
1273 |
+
|
1274 |
+
if self.mask_blur_x > 0:
|
1275 |
+
np_mask = np.array(image_mask)
|
1276 |
+
kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
|
1277 |
+
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
|
1278 |
+
image_mask = Image.fromarray(np_mask)
|
1279 |
+
|
1280 |
+
if self.mask_blur_y > 0:
|
1281 |
+
np_mask = np.array(image_mask)
|
1282 |
+
kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
|
1283 |
+
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
|
1284 |
+
image_mask = Image.fromarray(np_mask)
|
1285 |
+
|
1286 |
+
if self.inpaint_full_res:
|
1287 |
+
self.mask_for_overlay = image_mask
|
1288 |
+
mask = image_mask.convert('L')
|
1289 |
+
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
1290 |
+
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
1291 |
+
x1, y1, x2, y2 = crop_region
|
1292 |
+
|
1293 |
+
mask = mask.crop(crop_region)
|
1294 |
+
image_mask = images.resize_image(2, mask, self.width, self.height)
|
1295 |
+
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
1296 |
+
else:
|
1297 |
+
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
1298 |
+
np_mask = np.array(image_mask)
|
1299 |
+
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
|
1300 |
+
self.mask_for_overlay = Image.fromarray(np_mask)
|
1301 |
+
|
1302 |
+
self.overlay_images = []
|
1303 |
+
|
1304 |
+
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
|
1305 |
+
|
1306 |
+
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
|
1307 |
+
if add_color_corrections:
|
1308 |
+
self.color_corrections = []
|
1309 |
+
imgs = []
|
1310 |
+
for img in self.init_images:
|
1311 |
+
|
1312 |
+
# Save init image
|
1313 |
+
if opts.save_init_img:
|
1314 |
+
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
1315 |
+
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
|
1316 |
+
|
1317 |
+
image = images.flatten(img, opts.img2img_background_color)
|
1318 |
+
|
1319 |
+
if crop_region is None and self.resize_mode != 3:
|
1320 |
+
image = images.resize_image(self.resize_mode, image, self.width, self.height)
|
1321 |
+
|
1322 |
+
if image_mask is not None:
|
1323 |
+
image_masked = Image.new('RGBa', (image.width, image.height))
|
1324 |
+
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
1325 |
+
|
1326 |
+
self.overlay_images.append(image_masked.convert('RGBA'))
|
1327 |
+
|
1328 |
+
# crop_region is not None if we are doing inpaint full res
|
1329 |
+
if crop_region is not None:
|
1330 |
+
image = image.crop(crop_region)
|
1331 |
+
image = images.resize_image(2, image, self.width, self.height)
|
1332 |
+
|
1333 |
+
if image_mask is not None:
|
1334 |
+
if self.inpainting_fill != 1:
|
1335 |
+
image = masking.fill(image, latent_mask)
|
1336 |
+
|
1337 |
+
if add_color_corrections:
|
1338 |
+
self.color_corrections.append(setup_color_correction(image))
|
1339 |
+
|
1340 |
+
image = np.array(image).astype(np.float32) / 255.0
|
1341 |
+
image = np.moveaxis(image, 2, 0)
|
1342 |
+
|
1343 |
+
imgs.append(image)
|
1344 |
+
|
1345 |
+
if len(imgs) == 1:
|
1346 |
+
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
|
1347 |
+
if self.overlay_images is not None:
|
1348 |
+
self.overlay_images = self.overlay_images * self.batch_size
|
1349 |
+
|
1350 |
+
if self.color_corrections is not None and len(self.color_corrections) == 1:
|
1351 |
+
self.color_corrections = self.color_corrections * self.batch_size
|
1352 |
+
|
1353 |
+
elif len(imgs) <= self.batch_size:
|
1354 |
+
self.batch_size = len(imgs)
|
1355 |
+
batch_images = np.array(imgs)
|
1356 |
+
else:
|
1357 |
+
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
1358 |
+
|
1359 |
+
image = torch.from_numpy(batch_images)
|
1360 |
+
image = 2. * image - 1.
|
1361 |
+
image = image.to(shared.device, dtype=devices.dtype_vae)
|
1362 |
+
|
1363 |
+
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
1364 |
+
|
1365 |
+
if self.resize_mode == 3:
|
1366 |
+
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
1367 |
+
|
1368 |
+
if image_mask is not None:
|
1369 |
+
init_mask = latent_mask
|
1370 |
+
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
1371 |
+
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
1372 |
+
latmask = latmask[0]
|
1373 |
+
latmask = np.around(latmask)
|
1374 |
+
latmask = np.tile(latmask[None], (4, 1, 1))
|
1375 |
+
|
1376 |
+
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
|
1377 |
+
self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
|
1378 |
+
|
1379 |
+
# this needs to be fixed to be done in sample() using actual seeds for batches
|
1380 |
+
if self.inpainting_fill == 2:
|
1381 |
+
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
|
1382 |
+
elif self.inpainting_fill == 3:
|
1383 |
+
self.init_latent = self.init_latent * self.mask
|
1384 |
+
|
1385 |
+
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
|
1386 |
+
|
1387 |
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
1388 |
+
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
1389 |
+
|
1390 |
+
if self.initial_noise_multiplier != 1.0:
|
1391 |
+
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
1392 |
+
x *= self.initial_noise_multiplier
|
1393 |
+
|
1394 |
+
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
1395 |
+
|
1396 |
+
if self.mask is not None:
|
1397 |
+
samples = samples * self.nmask + self.init_latent * self.mask
|
1398 |
+
|
1399 |
+
del x
|
1400 |
+
devices.torch_gc()
|
1401 |
+
|
1402 |
+
return samples
|
1403 |
+
|
1404 |
+
def get_token_merging_ratio(self, for_hr=False):
|
1405 |
+
return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
|