toto10 committed
Commit 5002a4e · 1 Parent(s): 4cca094

5ddaf0cca1e144f6e16563ce10f39c55b15e5da289b24cf2ef5d01fe66eaa922

Files changed (50)
  1. modules/call_queue.py +117 -0
  2. modules/cmd_args.py +113 -0
  3. modules/codeformer/__pycache__/codeformer_arch.cpython-310.pyc +0 -0
  4. modules/codeformer/__pycache__/vqgan_arch.cpython-310.pyc +0 -0
  5. modules/codeformer/codeformer_arch.py +276 -0
  6. modules/codeformer/vqgan_arch.py +435 -0
  7. modules/codeformer_model.py +132 -0
  8. modules/config_states.py +197 -0
  9. modules/deepbooru.py +98 -0
  10. modules/deepbooru_model.py +678 -0
  11. modules/devices.py +171 -0
  12. modules/errors.py +85 -0
  13. modules/esrgan_model.py +229 -0
  14. modules/esrgan_model_arch.py +465 -0
  15. modules/extensions.py +163 -0
  16. modules/extra_networks.py +179 -0
  17. modules/extra_networks_hypernet.py +28 -0
  18. modules/extras.py +303 -0
  19. modules/face_restoration.py +19 -0
  20. modules/generation_parameters_copypaste.py +439 -0
  21. modules/gfpgan_model.py +110 -0
  22. modules/gitpython_hack.py +42 -0
  23. modules/hashes.py +81 -0
  24. modules/hypernetworks/__pycache__/hypernetwork.cpython-310.pyc +0 -0
  25. modules/hypernetworks/__pycache__/ui.cpython-310.pyc +0 -0
  26. modules/hypernetworks/hypernetwork.py +783 -0
  27. modules/hypernetworks/ui.py +38 -0
  28. modules/images.py +758 -0
  29. modules/img2img.py +245 -0
  30. modules/import_hook.py +5 -0
  31. modules/interrogate.py +223 -0
  32. modules/launch_utils.py +415 -0
  33. modules/localization.py +35 -0
  34. modules/lowvram.py +130 -0
  35. modules/mac_specific.py +86 -0
  36. modules/masking.py +99 -0
  37. modules/memmon.py +92 -0
  38. modules/modelloader.py +179 -0
  39. modules/models/diffusion/ddpm_edit.py +1455 -0
  40. modules/models/diffusion/uni_pc/__init__.py +1 -0
  41. modules/models/diffusion/uni_pc/__pycache__/__init__.cpython-310.pyc +0 -0
  42. modules/models/diffusion/uni_pc/__pycache__/sampler.cpython-310.pyc +0 -0
  43. modules/models/diffusion/uni_pc/__pycache__/uni_pc.cpython-310.pyc +0 -0
  44. modules/models/diffusion/uni_pc/sampler.py +101 -0
  45. modules/models/diffusion/uni_pc/uni_pc.py +863 -0
  46. modules/ngrok.py +30 -0
  47. modules/paths.py +65 -0
  48. modules/paths_internal.py +31 -0
  49. modules/postprocessing.py +109 -0
  50. modules/processing.py +1405 -0
modules/call_queue.py ADDED
@@ -0,0 +1,117 @@
1
+ from functools import wraps
2
+ import html
3
+ import threading
4
+ import time
5
+
6
+ from modules import shared, progress, errors
7
+
8
+ queue_lock = threading.Lock()
9
+
10
+
11
+ def wrap_queued_call(func):
12
+ def f(*args, **kwargs):
13
+ with queue_lock:
14
+ res = func(*args, **kwargs)
15
+
16
+ return res
17
+
18
+ return f
19
+
20
+
21
+ def wrap_gradio_gpu_call(func, extra_outputs=None):
22
+ @wraps(func)
23
+ def f(*args, **kwargs):
24
+
25
+ # if the first argument is a string that says "task(...)", it is treated as a job id
26
+ if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
27
+ id_task = args[0]
28
+ progress.add_task_to_queue(id_task)
29
+ else:
30
+ id_task = None
31
+
32
+ with queue_lock:
33
+ shared.state.begin(job=id_task)
34
+ progress.start_task(id_task)
35
+
36
+ try:
37
+ res = func(*args, **kwargs)
38
+ progress.record_results(id_task, res)
39
+ finally:
40
+ progress.finish_task(id_task)
41
+
42
+ shared.state.end()
43
+
44
+ return res
45
+
46
+ return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
47
+
48
+
49
+ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
50
+ @wraps(func)
51
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
52
+ run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
53
+ if run_memmon:
54
+ shared.mem_mon.monitor()
55
+ t = time.perf_counter()
56
+
57
+ try:
58
+ res = list(func(*args, **kwargs))
59
+ except Exception as e:
60
+ # When printing out our debug argument list,
61
+ # do not print out more than 128 KB of text
62
+ max_debug_str_len = 131072
63
+ message = "Error completing request"
64
+ arg_str = f"Arguments: {args} {kwargs}"
65
+ if len(arg_str) > max_debug_str_len:
66
+ arg_str = arg_str[:max_debug_str_len] + f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
67
+ errors.report(f"{message}\n{arg_str}", exc_info=True)
68
+
69
+ shared.state.job = ""
70
+ shared.state.job_count = 0
71
+
72
+ if extra_outputs_array is None:
73
+ extra_outputs_array = [None, '']
74
+
75
+ error_message = f'{type(e).__name__}: {e}'
76
+ res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
77
+
78
+ shared.state.skipped = False
79
+ shared.state.interrupted = False
80
+ shared.state.job_count = 0
81
+
82
+ if not add_stats:
83
+ return tuple(res)
84
+
85
+ elapsed = time.perf_counter() - t
86
+ elapsed_m = int(elapsed // 60)
87
+ elapsed_s = elapsed % 60
88
+ elapsed_text = f"{elapsed_s:.1f} sec."
89
+ if elapsed_m > 0:
90
+ elapsed_text = f"{elapsed_m} min. "+elapsed_text
91
+
92
+ if run_memmon:
93
+ mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
94
+ active_peak = mem_stats['active_peak']
95
+ reserved_peak = mem_stats['reserved_peak']
96
+ sys_peak = mem_stats['system_peak']
97
+ sys_total = mem_stats['total']
98
+ sys_pct = sys_peak/max(sys_total, 1) * 100
99
+
100
+ tooltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
101
+ tooltip_r = "Reserved: total amount of video memory allocated by the Torch library"
102
+ tooltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity"
103
+
104
+ text_a = f"<abbr title='{tooltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
105
+ text_r = f"<abbr title='{tooltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
106
+ text_sys = f"<abbr title='{tooltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
107
+
108
+ vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
109
+ else:
110
+ vram_html = ''
111
+
112
+ # last item is always HTML
113
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
114
+
115
+ return tuple(res)
116
+
117
+ return f
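
A minimal sketch of the locking pattern wrap_queued_call relies on above, with a hypothetical stand-in function in place of the webui internals: every wrapped call is serialized through one shared lock, so queued jobs never overlap.

import threading

queue_lock = threading.Lock()

def wrap_queued_call(func):
    # run the wrapped function while holding the shared lock, as in modules/call_queue.py
    def f(*args, **kwargs):
        with queue_lock:
            return func(*args, **kwargs)
    return f

@wrap_queued_call
def render(prompt):
    return f"rendered: {prompt}"   # stand-in for a GPU-bound job

print(render("a lighthouse at dusk"))  # concurrent callers take turns on queue_lock
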
modules/cmd_args.py ADDED
@@ -0,0 +1,113 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
5
+
6
+ parser = argparse.ArgumentParser()
7
+
8
+ parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
9
+ parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
10
+ parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
11
+ parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
12
+ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
13
+ parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
14
+ parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
15
+ parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
16
+ parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
17
+ parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
18
+ parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
19
+ parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
20
+ parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
21
+ parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
22
+ parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
23
+ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
24
+ parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
25
+ parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
26
+ parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
27
+ parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
28
+ parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
29
+ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
30
+ parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
31
+ parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
32
+ parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
33
+ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
34
+ parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
35
+ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
36
+ parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations that sacrifice a little speed for low VRAM usage")
37
+ parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations that sacrifice a lot of speed for very low VRAM usage")
38
+ parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
39
+ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
40
+ parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
41
+ parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
42
+ parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
43
+ parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
44
+ parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
45
+ parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
46
+ parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
47
+ parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
48
+ parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
49
+ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
50
+ parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
51
+ parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
52
+ parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
53
+ parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
54
+ parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
55
+ parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
56
+ parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
57
+ parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
58
+ parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
59
+ parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
60
+ parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
61
+ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
62
+ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
63
+ parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
64
+ parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
65
+ parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
66
+ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
67
+ parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
68
+ parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
69
+ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
70
+ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing it to respond to network requests")
71
+ parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
72
+ parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
73
+ parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
74
+ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
75
+ parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
76
+ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
77
+ parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
78
+ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
79
+ parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path, e.g. "/path/to/auth/file"; uses the same auth format as --gradio-auth', default=None)
80
+ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
81
+ parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
82
+ parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
83
+ parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
84
+ parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
85
+ parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
86
+ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
87
+ parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
88
+ parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
89
+ parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
90
+ parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
91
+ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
92
+ parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
93
+ parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
94
+ parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
95
+ parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
96
+ parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
97
+ parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
98
+ parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
99
+ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
100
+ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
101
+ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
102
+ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
103
+ parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
104
+ parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
105
+ parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
106
+ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
107
+ parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
108
+ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
109
+ parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
110
+ parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
111
+ parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
112
+ parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
113
+ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
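
The parser above is typically consumed elsewhere in the codebase, commonly via parse_known_args() so that flags it does not recognize do not abort startup. A small illustration of that argparse pattern, using a couple of representative arguments only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint")
parser.add_argument("--lowvram", action='store_true', help="very low VRAM mode")

# parse_known_args() returns (namespace, unrecognized_args) instead of raising an error
cmd_opts, unknown = parser.parse_known_args(["--lowvram", "--some-extension-flag"])
print(cmd_opts.lowvram)  # True
print(unknown)           # ['--some-extension-flag']
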
modules/codeformer/__pycache__/codeformer_arch.cpython-310.pyc ADDED
Binary file (9.16 kB).
 
modules/codeformer/__pycache__/vqgan_arch.cpython-310.pyc ADDED
Binary file (11.1 kB).
 
modules/codeformer/codeformer_arch.py ADDED
@@ -0,0 +1,276 @@
1
+ # this file is copied from the CodeFormer repository. Please see the comment in modules/codeformer_model.py
2
+
3
+ import math
4
+ import torch
5
+ from torch import nn, Tensor
6
+ import torch.nn.functional as F
7
+ from typing import Optional
8
+
9
+ from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
10
+ from basicsr.utils.registry import ARCH_REGISTRY
11
+
12
+ def calc_mean_std(feat, eps=1e-5):
13
+ """Calculate mean and std for adaptive_instance_normalization.
14
+
15
+ Args:
16
+ feat (Tensor): 4D tensor.
17
+ eps (float): A small value added to the variance to avoid
18
+ divide-by-zero. Default: 1e-5.
19
+ """
20
+ size = feat.size()
21
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
22
+ b, c = size[:2]
23
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
24
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
25
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
26
+ return feat_mean, feat_std
27
+
28
+
29
+ def adaptive_instance_normalization(content_feat, style_feat):
30
+ """Adaptive instance normalization.
31
+
32
+ Adjust the reference features to have the similar color and illuminations
33
+ as those in the degradate features.
34
+
35
+ Args:
36
+ content_feat (Tensor): The reference feature.
37
+ style_feat (Tensor): The degradate features.
38
+ """
39
+ size = content_feat.size()
40
+ style_mean, style_std = calc_mean_std(style_feat)
41
+ content_mean, content_std = calc_mean_std(content_feat)
42
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
43
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
44
+
45
+
46
+ class PositionEmbeddingSine(nn.Module):
47
+ """
48
+ This is a more standard version of the position embedding, very similar to the one
49
+ used by the Attention is all you need paper, generalized to work on images.
50
+ """
51
+
52
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
53
+ super().__init__()
54
+ self.num_pos_feats = num_pos_feats
55
+ self.temperature = temperature
56
+ self.normalize = normalize
57
+ if scale is not None and normalize is False:
58
+ raise ValueError("normalize should be True if scale is passed")
59
+ if scale is None:
60
+ scale = 2 * math.pi
61
+ self.scale = scale
62
+
63
+ def forward(self, x, mask=None):
64
+ if mask is None:
65
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
66
+ not_mask = ~mask
67
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
68
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
69
+ if self.normalize:
70
+ eps = 1e-6
71
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
72
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
73
+
74
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
75
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
76
+
77
+ pos_x = x_embed[:, :, :, None] / dim_t
78
+ pos_y = y_embed[:, :, :, None] / dim_t
79
+ pos_x = torch.stack(
80
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
81
+ ).flatten(3)
82
+ pos_y = torch.stack(
83
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
84
+ ).flatten(3)
85
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
86
+ return pos
87
+
88
+ def _get_activation_fn(activation):
89
+ """Return an activation function given a string"""
90
+ if activation == "relu":
91
+ return F.relu
92
+ if activation == "gelu":
93
+ return F.gelu
94
+ if activation == "glu":
95
+ return F.glu
96
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
97
+
98
+
99
+ class TransformerSALayer(nn.Module):
100
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
101
+ super().__init__()
102
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
103
+ # Implementation of Feedforward model - MLP
104
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
105
+ self.dropout = nn.Dropout(dropout)
106
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
107
+
108
+ self.norm1 = nn.LayerNorm(embed_dim)
109
+ self.norm2 = nn.LayerNorm(embed_dim)
110
+ self.dropout1 = nn.Dropout(dropout)
111
+ self.dropout2 = nn.Dropout(dropout)
112
+
113
+ self.activation = _get_activation_fn(activation)
114
+
115
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
116
+ return tensor if pos is None else tensor + pos
117
+
118
+ def forward(self, tgt,
119
+ tgt_mask: Optional[Tensor] = None,
120
+ tgt_key_padding_mask: Optional[Tensor] = None,
121
+ query_pos: Optional[Tensor] = None):
122
+
123
+ # self attention
124
+ tgt2 = self.norm1(tgt)
125
+ q = k = self.with_pos_embed(tgt2, query_pos)
126
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
127
+ key_padding_mask=tgt_key_padding_mask)[0]
128
+ tgt = tgt + self.dropout1(tgt2)
129
+
130
+ # ffn
131
+ tgt2 = self.norm2(tgt)
132
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
133
+ tgt = tgt + self.dropout2(tgt2)
134
+ return tgt
135
+
136
+ class Fuse_sft_block(nn.Module):
137
+ def __init__(self, in_ch, out_ch):
138
+ super().__init__()
139
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
140
+
141
+ self.scale = nn.Sequential(
142
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
143
+ nn.LeakyReLU(0.2, True),
144
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
145
+
146
+ self.shift = nn.Sequential(
147
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
148
+ nn.LeakyReLU(0.2, True),
149
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
150
+
151
+ def forward(self, enc_feat, dec_feat, w=1):
152
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
153
+ scale = self.scale(enc_feat)
154
+ shift = self.shift(enc_feat)
155
+ residual = w * (dec_feat * scale + shift)
156
+ out = dec_feat + residual
157
+ return out
158
+
159
+
160
+ @ARCH_REGISTRY.register()
161
+ class CodeFormer(VQAutoEncoder):
162
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
163
+ codebook_size=1024, latent_size=256,
164
+ connect_list=('32', '64', '128', '256'),
165
+ fix_modules=('quantize', 'generator')):
166
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)
167
+
168
+ if fix_modules is not None:
169
+ for module in fix_modules:
170
+ for param in getattr(self, module).parameters():
171
+ param.requires_grad = False
172
+
173
+ self.connect_list = connect_list
174
+ self.n_layers = n_layers
175
+ self.dim_embd = dim_embd
176
+ self.dim_mlp = dim_embd*2
177
+
178
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
179
+ self.feat_emb = nn.Linear(256, self.dim_embd)
180
+
181
+ # transformer
182
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
183
+ for _ in range(self.n_layers)])
184
+
185
+ # logits_predict head
186
+ self.idx_pred_layer = nn.Sequential(
187
+ nn.LayerNorm(dim_embd),
188
+ nn.Linear(dim_embd, codebook_size, bias=False))
189
+
190
+ self.channels = {
191
+ '16': 512,
192
+ '32': 256,
193
+ '64': 256,
194
+ '128': 128,
195
+ '256': 128,
196
+ '512': 64,
197
+ }
198
+
199
+ # after second residual block for > 16, before attn layer for ==16
200
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
201
+ # after first residual block for > 16, before attn layer for ==16
202
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
203
+
204
+ # fuse_convs_dict
205
+ self.fuse_convs_dict = nn.ModuleDict()
206
+ for f_size in self.connect_list:
207
+ in_ch = self.channels[f_size]
208
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
209
+
210
+ def _init_weights(self, module):
211
+ if isinstance(module, (nn.Linear, nn.Embedding)):
212
+ module.weight.data.normal_(mean=0.0, std=0.02)
213
+ if isinstance(module, nn.Linear) and module.bias is not None:
214
+ module.bias.data.zero_()
215
+ elif isinstance(module, nn.LayerNorm):
216
+ module.bias.data.zero_()
217
+ module.weight.data.fill_(1.0)
218
+
219
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
220
+ # ################### Encoder #####################
221
+ enc_feat_dict = {}
222
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
223
+ for i, block in enumerate(self.encoder.blocks):
224
+ x = block(x)
225
+ if i in out_list:
226
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
227
+
228
+ lq_feat = x
229
+ # ################# Transformer ###################
230
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
231
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
232
+ # BCHW -> BC(HW) -> (HW)BC
233
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
234
+ query_emb = feat_emb
235
+ # Transformer encoder
236
+ for layer in self.ft_layers:
237
+ query_emb = layer(query_emb, query_pos=pos_emb)
238
+
239
+ # output logits
240
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
241
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
242
+
243
+ if code_only: # for training stage II
244
+ # logits doesn't need softmax before cross_entropy loss
245
+ return logits, lq_feat
246
+
247
+ # ################# Quantization ###################
248
+ # if self.training:
249
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
250
+ # # b(hw)c -> bc(hw) -> bchw
251
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
252
+ # ------------
253
+ soft_one_hot = F.softmax(logits, dim=2)
254
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
255
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
256
+ # preserve gradients
257
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
258
+
259
+ if detach_16:
260
+ quant_feat = quant_feat.detach() # for training stage III
261
+ if adain:
262
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
263
+
264
+ # ################## Generator ####################
265
+ x = quant_feat
266
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
267
+
268
+ for i, block in enumerate(self.generator.blocks):
269
+ x = block(x)
270
+ if i in fuse_list: # fuse after i-th block
271
+ f_size = str(x.shape[-1])
272
+ if w>0:
273
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
274
+ out = x
275
+ # logits doesn't need softmax before cross_entropy loss
276
+ return out, logits, lq_feat
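
For intuition, adaptive_instance_normalization above shifts one 4D feature map onto the per-channel mean/std of another. A self-contained numeric check with random tensors (illustrative shapes and values only):

import torch

def calc_mean_std(feat, eps=1e-5):
    # per-channel statistics over the spatial dimensions, as in the file above
    b, c = feat.shape[:2]
    var = feat.view(b, c, -1).var(dim=2) + eps
    return feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1), var.sqrt().view(b, c, 1, 1)

content = torch.randn(2, 8, 16, 16)
style = torch.randn(2, 8, 16, 16) * 3 + 1            # roughly mean 1, std 3

style_mean, style_std = calc_mean_std(style)
content_mean, content_std = calc_mean_std(content)
out = (content - content_mean) / content_std * style_std + style_mean

# the re-normalized features now carry the style statistics
print(out.mean().item(), out.std().item())            # approximately 1.0 and 3.0
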
modules/codeformer/vqgan_arch.py ADDED
@@ -0,0 +1,435 @@
1
+ # this file is copied from the CodeFormer repository. Please see the comment in modules/codeformer_model.py
2
+
3
+ '''
4
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
5
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
6
+
7
+ '''
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from basicsr.utils import get_root_logger
12
+ from basicsr.utils.registry import ARCH_REGISTRY
13
+
14
+ def normalize(in_channels):
15
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
16
+
17
+
18
+ @torch.jit.script
19
+ def swish(x):
20
+ return x*torch.sigmoid(x)
21
+
22
+
23
+ # Define VQVAE classes
24
+ class VectorQuantizer(nn.Module):
25
+ def __init__(self, codebook_size, emb_dim, beta):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.codebook_size = codebook_size # number of embeddings
28
+ self.emb_dim = emb_dim # dimension of embedding
29
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
30
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
31
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
32
+
33
+ def forward(self, z):
34
+ # reshape z -> (batch, height, width, channel) and flatten
35
+ z = z.permute(0, 2, 3, 1).contiguous()
36
+ z_flattened = z.view(-1, self.emb_dim)
37
+
38
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
39
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
40
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
41
+
42
+ mean_distance = torch.mean(d)
43
+ # find closest encodings
44
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
45
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
46
+ # [0-1], higher score, higher confidence
47
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
48
+
49
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
50
+ min_encodings.scatter_(1, min_encoding_indices, 1)
51
+
52
+ # get quantized latent vectors
53
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
54
+ # compute loss for embedding
55
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
56
+ # preserve gradients
57
+ z_q = z + (z_q - z).detach()
58
+
59
+ # perplexity
60
+ e_mean = torch.mean(min_encodings, dim=0)
61
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
62
+ # reshape back to match original input shape
63
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
64
+
65
+ return z_q, loss, {
66
+ "perplexity": perplexity,
67
+ "min_encodings": min_encodings,
68
+ "min_encoding_indices": min_encoding_indices,
69
+ "min_encoding_scores": min_encoding_scores,
70
+ "mean_distance": mean_distance
71
+ }
72
+
73
+ def get_codebook_feat(self, indices, shape):
74
+ # input indices: batch*token_num -> (batch*token_num)*1
75
+ # shape: batch, height, width, channel
76
+ indices = indices.view(-1,1)
77
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
78
+ min_encodings.scatter_(1, indices, 1)
79
+ # get quantized latent vectors
80
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
81
+
82
+ if shape is not None: # reshape back to match original input shape
83
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
84
+
85
+ return z_q
86
+
87
+
88
+ class GumbelQuantizer(nn.Module):
89
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
90
+ super().__init__()
91
+ self.codebook_size = codebook_size # number of embeddings
92
+ self.emb_dim = emb_dim # dimension of embedding
93
+ self.straight_through = straight_through
94
+ self.temperature = temp_init
95
+ self.kl_weight = kl_weight
96
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
97
+ self.embed = nn.Embedding(codebook_size, emb_dim)
98
+
99
+ def forward(self, z):
100
+ hard = self.straight_through if self.training else True
101
+
102
+ logits = self.proj(z)
103
+
104
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
105
+
106
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
107
+
108
+ # + kl divergence to the prior loss
109
+ qy = F.softmax(logits, dim=1)
110
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
111
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
112
+
113
+ return z_q, diff, {
114
+ "min_encoding_indices": min_encoding_indices
115
+ }
116
+
117
+
118
+ class Downsample(nn.Module):
119
+ def __init__(self, in_channels):
120
+ super().__init__()
121
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
122
+
123
+ def forward(self, x):
124
+ pad = (0, 1, 0, 1)
125
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
126
+ x = self.conv(x)
127
+ return x
128
+
129
+
130
+ class Upsample(nn.Module):
131
+ def __init__(self, in_channels):
132
+ super().__init__()
133
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
134
+
135
+ def forward(self, x):
136
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
137
+ x = self.conv(x)
138
+
139
+ return x
140
+
141
+
142
+ class ResBlock(nn.Module):
143
+ def __init__(self, in_channels, out_channels=None):
144
+ super(ResBlock, self).__init__()
145
+ self.in_channels = in_channels
146
+ self.out_channels = in_channels if out_channels is None else out_channels
147
+ self.norm1 = normalize(in_channels)
148
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
149
+ self.norm2 = normalize(out_channels)
150
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
151
+ if self.in_channels != self.out_channels:
152
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
153
+
154
+ def forward(self, x_in):
155
+ x = x_in
156
+ x = self.norm1(x)
157
+ x = swish(x)
158
+ x = self.conv1(x)
159
+ x = self.norm2(x)
160
+ x = swish(x)
161
+ x = self.conv2(x)
162
+ if self.in_channels != self.out_channels:
163
+ x_in = self.conv_out(x_in)
164
+
165
+ return x + x_in
166
+
167
+
168
+ class AttnBlock(nn.Module):
169
+ def __init__(self, in_channels):
170
+ super().__init__()
171
+ self.in_channels = in_channels
172
+
173
+ self.norm = normalize(in_channels)
174
+ self.q = torch.nn.Conv2d(
175
+ in_channels,
176
+ in_channels,
177
+ kernel_size=1,
178
+ stride=1,
179
+ padding=0
180
+ )
181
+ self.k = torch.nn.Conv2d(
182
+ in_channels,
183
+ in_channels,
184
+ kernel_size=1,
185
+ stride=1,
186
+ padding=0
187
+ )
188
+ self.v = torch.nn.Conv2d(
189
+ in_channels,
190
+ in_channels,
191
+ kernel_size=1,
192
+ stride=1,
193
+ padding=0
194
+ )
195
+ self.proj_out = torch.nn.Conv2d(
196
+ in_channels,
197
+ in_channels,
198
+ kernel_size=1,
199
+ stride=1,
200
+ padding=0
201
+ )
202
+
203
+ def forward(self, x):
204
+ h_ = x
205
+ h_ = self.norm(h_)
206
+ q = self.q(h_)
207
+ k = self.k(h_)
208
+ v = self.v(h_)
209
+
210
+ # compute attention
211
+ b, c, h, w = q.shape
212
+ q = q.reshape(b, c, h*w)
213
+ q = q.permute(0, 2, 1)
214
+ k = k.reshape(b, c, h*w)
215
+ w_ = torch.bmm(q, k)
216
+ w_ = w_ * (int(c)**(-0.5))
217
+ w_ = F.softmax(w_, dim=2)
218
+
219
+ # attend to values
220
+ v = v.reshape(b, c, h*w)
221
+ w_ = w_.permute(0, 2, 1)
222
+ h_ = torch.bmm(v, w_)
223
+ h_ = h_.reshape(b, c, h, w)
224
+
225
+ h_ = self.proj_out(h_)
226
+
227
+ return x+h_
228
+
229
+
230
+ class Encoder(nn.Module):
231
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
232
+ super().__init__()
233
+ self.nf = nf
234
+ self.num_resolutions = len(ch_mult)
235
+ self.num_res_blocks = num_res_blocks
236
+ self.resolution = resolution
237
+ self.attn_resolutions = attn_resolutions
238
+
239
+ curr_res = self.resolution
240
+ in_ch_mult = (1,)+tuple(ch_mult)
241
+
242
+ blocks = []
243
+ # initial convolution
244
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
245
+
246
+ # residual and downsampling blocks, with attention on smaller res (16x16)
247
+ for i in range(self.num_resolutions):
248
+ block_in_ch = nf * in_ch_mult[i]
249
+ block_out_ch = nf * ch_mult[i]
250
+ for _ in range(self.num_res_blocks):
251
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
252
+ block_in_ch = block_out_ch
253
+ if curr_res in attn_resolutions:
254
+ blocks.append(AttnBlock(block_in_ch))
255
+
256
+ if i != self.num_resolutions - 1:
257
+ blocks.append(Downsample(block_in_ch))
258
+ curr_res = curr_res // 2
259
+
260
+ # non-local attention block
261
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
262
+ blocks.append(AttnBlock(block_in_ch))
263
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
264
+
265
+ # normalise and convert to latent size
266
+ blocks.append(normalize(block_in_ch))
267
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
268
+ self.blocks = nn.ModuleList(blocks)
269
+
270
+ def forward(self, x):
271
+ for block in self.blocks:
272
+ x = block(x)
273
+
274
+ return x
275
+
276
+
277
+ class Generator(nn.Module):
278
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
279
+ super().__init__()
280
+ self.nf = nf
281
+ self.ch_mult = ch_mult
282
+ self.num_resolutions = len(self.ch_mult)
283
+ self.num_res_blocks = res_blocks
284
+ self.resolution = img_size
285
+ self.attn_resolutions = attn_resolutions
286
+ self.in_channels = emb_dim
287
+ self.out_channels = 3
288
+ block_in_ch = self.nf * self.ch_mult[-1]
289
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
290
+
291
+ blocks = []
292
+ # initial conv
293
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
294
+
295
+ # non-local attention block
296
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
297
+ blocks.append(AttnBlock(block_in_ch))
298
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
299
+
300
+ for i in reversed(range(self.num_resolutions)):
301
+ block_out_ch = self.nf * self.ch_mult[i]
302
+
303
+ for _ in range(self.num_res_blocks):
304
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
305
+ block_in_ch = block_out_ch
306
+
307
+ if curr_res in self.attn_resolutions:
308
+ blocks.append(AttnBlock(block_in_ch))
309
+
310
+ if i != 0:
311
+ blocks.append(Upsample(block_in_ch))
312
+ curr_res = curr_res * 2
313
+
314
+ blocks.append(normalize(block_in_ch))
315
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
316
+
317
+ self.blocks = nn.ModuleList(blocks)
318
+
319
+
320
+ def forward(self, x):
321
+ for block in self.blocks:
322
+ x = block(x)
323
+
324
+ return x
325
+
326
+
327
+ @ARCH_REGISTRY.register()
328
+ class VQAutoEncoder(nn.Module):
329
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
330
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
331
+ super().__init__()
332
+ logger = get_root_logger()
333
+ self.in_channels = 3
334
+ self.nf = nf
335
+ self.n_blocks = res_blocks
336
+ self.codebook_size = codebook_size
337
+ self.embed_dim = emb_dim
338
+ self.ch_mult = ch_mult
339
+ self.resolution = img_size
340
+ self.attn_resolutions = attn_resolutions or [16]
341
+ self.quantizer_type = quantizer
342
+ self.encoder = Encoder(
343
+ self.in_channels,
344
+ self.nf,
345
+ self.embed_dim,
346
+ self.ch_mult,
347
+ self.n_blocks,
348
+ self.resolution,
349
+ self.attn_resolutions
350
+ )
351
+ if self.quantizer_type == "nearest":
352
+ self.beta = beta #0.25
353
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
354
+ elif self.quantizer_type == "gumbel":
355
+ self.gumbel_num_hiddens = emb_dim
356
+ self.straight_through = gumbel_straight_through
357
+ self.kl_weight = gumbel_kl_weight
358
+ self.quantize = GumbelQuantizer(
359
+ self.codebook_size,
360
+ self.embed_dim,
361
+ self.gumbel_num_hiddens,
362
+ self.straight_through,
363
+ self.kl_weight
364
+ )
365
+ self.generator = Generator(
366
+ self.nf,
367
+ self.embed_dim,
368
+ self.ch_mult,
369
+ self.n_blocks,
370
+ self.resolution,
371
+ self.attn_resolutions
372
+ )
373
+
374
+ if model_path is not None:
375
+ chkpt = torch.load(model_path, map_location='cpu')
376
+ if 'params_ema' in chkpt:
377
+ self.load_state_dict(chkpt['params_ema'])
378
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
379
+ elif 'params' in chkpt:
380
+ self.load_state_dict(chkpt['params'])
381
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
382
+ else:
383
+ raise ValueError('Wrong params!')
384
+
385
+
386
+ def forward(self, x):
387
+ x = self.encoder(x)
388
+ quant, codebook_loss, quant_stats = self.quantize(x)
389
+ x = self.generator(quant)
390
+ return x, codebook_loss, quant_stats
391
+
392
+
393
+
394
+ # patch based discriminator
395
+ @ARCH_REGISTRY.register()
396
+ class VQGANDiscriminator(nn.Module):
397
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
398
+ super().__init__()
399
+
400
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
401
+ ndf_mult = 1
402
+ ndf_mult_prev = 1
403
+ for n in range(1, n_layers): # gradually increase the number of filters
404
+ ndf_mult_prev = ndf_mult
405
+ ndf_mult = min(2 ** n, 8)
406
+ layers += [
407
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
408
+ nn.BatchNorm2d(ndf * ndf_mult),
409
+ nn.LeakyReLU(0.2, True)
410
+ ]
411
+
412
+ ndf_mult_prev = ndf_mult
413
+ ndf_mult = min(2 ** n_layers, 8)
414
+
415
+ layers += [
416
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
417
+ nn.BatchNorm2d(ndf * ndf_mult),
418
+ nn.LeakyReLU(0.2, True)
419
+ ]
420
+
421
+ layers += [
422
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
423
+ self.main = nn.Sequential(*layers)
424
+
425
+ if model_path is not None:
426
+ chkpt = torch.load(model_path, map_location='cpu')
427
+ if 'params_d' in chkpt:
428
+ self.load_state_dict(chkpt['params_d'])
429
+ elif 'params' in chkpt:
430
+ self.load_state_dict(chkpt['params'])
431
+ else:
432
+ raise ValueError('Wrong params!')
433
+
434
+ def forward(self, x):
435
+ return self.main(x)
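
The heart of VectorQuantizer above is a nearest-codebook lookup combined with a straight-through gradient. Stripped to its essentials with a random codebook (shapes chosen purely for illustration):

import torch

codebook = torch.randn(1024, 256)                      # (codebook_size, emb_dim)
z = torch.randn(4, 256, 16, 16, requires_grad=True)    # encoder output, BCHW

flat = z.permute(0, 2, 3, 1).reshape(-1, 256)          # one row per spatial position
idx = torch.cdist(flat, codebook).argmin(dim=1)        # index of the nearest code
z_q = codebook[idx].view(4, 16, 16, 256).permute(0, 3, 1, 2)

# straight-through estimator: the forward pass uses the quantized values,
# the backward pass copies gradients straight to the encoder output z
z_q = z + (z_q - z).detach()
z_q.sum().backward()
print(z.grad.shape)                                     # torch.Size([4, 256, 16, 16])
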
modules/codeformer_model.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+
3
+ import cv2
4
+ import torch
5
+
6
+ import modules.face_restoration
7
+ import modules.shared
8
+ from modules import shared, devices, modelloader, errors
9
+ from modules.paths import models_path
10
+
11
+ # the CodeFormer authors chose to include a modified basicsr library in their project, which makes
12
+ # it impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
13
+ # I am therefore including some files from CodeFormer here to work around this issue.
14
+ model_dir = "Codeformer"
15
+ model_path = os.path.join(models_path, model_dir)
16
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
17
+
18
+ codeformer = None
19
+
20
+
21
+ def setup_model(dirname):
22
+ os.makedirs(model_path, exist_ok=True)
23
+
24
+ path = modules.paths.paths.get("CodeFormer", None)
25
+ if path is None:
26
+ return
27
+
28
+ try:
29
+ from torchvision.transforms.functional import normalize
30
+ from modules.codeformer.codeformer_arch import CodeFormer
31
+ from basicsr.utils import img2tensor, tensor2img
32
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
33
+ from facelib.detection.retinaface import retinaface
34
+
35
+ net_class = CodeFormer
36
+
37
+ class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
38
+ def name(self):
39
+ return "CodeFormer"
40
+
41
+ def __init__(self, dirname):
42
+ self.net = None
43
+ self.face_helper = None
44
+ self.cmd_dir = dirname
45
+
46
+ def create_models(self):
47
+
48
+ if self.net is not None and self.face_helper is not None:
49
+ self.net.to(devices.device_codeformer)
50
+ return self.net, self.face_helper
51
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
52
+ if len(model_paths) != 0:
53
+ ckpt_path = model_paths[0]
54
+ else:
55
+ print("Unable to load codeformer model.")
56
+ return None, None
57
+ net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
58
+ checkpoint = torch.load(ckpt_path)['params_ema']
59
+ net.load_state_dict(checkpoint)
60
+ net.eval()
61
+
62
+ if hasattr(retinaface, 'device'):
63
+ retinaface.device = devices.device_codeformer
64
+ face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
65
+
66
+ self.net = net
67
+ self.face_helper = face_helper
68
+
69
+ return net, face_helper
70
+
71
+ def send_model_to(self, device):
72
+ self.net.to(device)
73
+ self.face_helper.face_det.to(device)
74
+ self.face_helper.face_parse.to(device)
75
+
76
+ def restore(self, np_image, w=None):
77
+ np_image = np_image[:, :, ::-1]
78
+
79
+ original_resolution = np_image.shape[0:2]
80
+
81
+ self.create_models()
82
+ if self.net is None or self.face_helper is None:
83
+ return np_image
84
+
85
+ self.send_model_to(devices.device_codeformer)
86
+
87
+ self.face_helper.clean_all()
88
+ self.face_helper.read_image(np_image)
89
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
90
+ self.face_helper.align_warp_face()
91
+
92
+ for cropped_face in self.face_helper.cropped_faces:
93
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
94
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
95
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
96
+
97
+ try:
98
+ with torch.no_grad():
99
+ output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
100
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
101
+ del output
102
+ devices.torch_gc()
103
+ except Exception:
104
+ errors.report('Failed inference for CodeFormer', exc_info=True)
105
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
106
+
107
+ restored_face = restored_face.astype('uint8')
108
+ self.face_helper.add_restored_face(restored_face)
109
+
110
+ self.face_helper.get_inverse_affine(None)
111
+
112
+ restored_img = self.face_helper.paste_faces_to_input_image()
113
+ restored_img = restored_img[:, :, ::-1]
114
+
115
+ if original_resolution != restored_img.shape[0:2]:
116
+ restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
117
+
118
+ self.face_helper.clean_all()
119
+
120
+ if shared.opts.face_restoration_unload:
121
+ self.send_model_to(devices.cpu)
122
+
123
+ return restored_img
124
+
125
+ global codeformer
126
+ codeformer = FaceRestorerCodeFormer(dirname)
127
+ shared.face_restorers.append(codeformer)
128
+
129
+ except Exception:
130
+ errors.report("Error setting up CodeFormer", exc_info=True)
131
+
132
+ # sys.path = stored_sys_path
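
A compact sketch of the pattern FaceRestorerCodeFormer follows above: build the network lazily on first use, move it to the compute device around each call, and send it back to the CPU afterwards. The tiny conv net here is just a placeholder for the real model:

import torch

class LazyRestorer:
    def __init__(self):
        self.net = None

    def create_models(self):
        if self.net is None:                      # build once, reuse afterwards
            self.net = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
        return self.net

    def restore(self, image, device="cpu"):
        net = self.create_models().to(device)     # move to the compute device
        with torch.no_grad():
            out = net(image.to(device))
        net.to("cpu")                             # release device memory when done
        return out.cpu()

print(LazyRestorer().restore(torch.rand(1, 3, 64, 64)).shape)   # torch.Size([1, 3, 64, 64])
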
modules/config_states.py ADDED
@@ -0,0 +1,197 @@
1
+ """
2
+ Supports saving and restoring webui and extensions from a known working set of commits
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import time
8
+ import tqdm
9
+
10
+ from datetime import datetime
11
+ from collections import OrderedDict
12
+ import git
13
+
14
+ from modules import shared, extensions, errors
15
+ from modules.paths_internal import script_path, config_states_dir
16
+
17
+
18
+ all_config_states = OrderedDict()
19
+
20
+
21
+ def list_config_states():
22
+ global all_config_states
23
+
24
+ all_config_states.clear()
25
+ os.makedirs(config_states_dir, exist_ok=True)
26
+
27
+ config_states = []
28
+ for filename in os.listdir(config_states_dir):
29
+ if filename.endswith(".json"):
30
+ path = os.path.join(config_states_dir, filename)
31
+ with open(path, "r", encoding="utf-8") as f:
32
+ j = json.load(f)
33
+ j["filepath"] = path
34
+ config_states.append(j)
35
+
36
+ config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
37
+
38
+ for cs in config_states:
39
+ timestamp = time.asctime(time.gmtime(cs["created_at"]))
40
+ name = cs.get("name", "Config")
41
+ full_name = f"{name}: {timestamp}"
42
+ all_config_states[full_name] = cs
43
+
44
+ return all_config_states
45
+
46
+
47
+ def get_webui_config():
48
+ webui_repo = None
49
+
50
+ try:
51
+ if os.path.exists(os.path.join(script_path, ".git")):
52
+ webui_repo = git.Repo(script_path)
53
+ except Exception:
54
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
55
+
56
+ webui_remote = None
57
+ webui_commit_hash = None
58
+ webui_commit_date = None
59
+ webui_branch = None
60
+ if webui_repo and not webui_repo.bare:
61
+ try:
62
+ webui_remote = next(webui_repo.remote().urls, None)
63
+ head = webui_repo.head.commit
64
+ webui_commit_date = webui_repo.head.commit.committed_date
65
+ webui_commit_hash = head.hexsha
66
+ webui_branch = webui_repo.active_branch.name
67
+
68
+ except Exception:
69
+ webui_remote = None
70
+
71
+ return {
72
+ "remote": webui_remote,
73
+ "commit_hash": webui_commit_hash,
74
+ "commit_date": webui_commit_date,
75
+ "branch": webui_branch,
76
+ }
77
+
78
+
79
+ def get_extension_config():
80
+ ext_config = {}
81
+
82
+ for ext in extensions.extensions:
83
+ ext.read_info_from_repo()
84
+
85
+ entry = {
86
+ "name": ext.name,
87
+ "path": ext.path,
88
+ "enabled": ext.enabled,
89
+ "is_builtin": ext.is_builtin,
90
+ "remote": ext.remote,
91
+ "commit_hash": ext.commit_hash,
92
+ "commit_date": ext.commit_date,
93
+ "branch": ext.branch,
94
+ "have_info_from_repo": ext.have_info_from_repo
95
+ }
96
+
97
+ ext_config[ext.name] = entry
98
+
99
+ return ext_config
100
+
101
+
102
+ def get_config():
103
+ creation_time = datetime.now().timestamp()
104
+ webui_config = get_webui_config()
105
+ ext_config = get_extension_config()
106
+
107
+ return {
108
+ "created_at": creation_time,
109
+ "webui": webui_config,
110
+ "extensions": ext_config
111
+ }
112
+
113
+
114
+ def restore_webui_config(config):
115
+ print("* Restoring webui state...")
116
+
117
+ if "webui" not in config:
118
+ print("Error: No webui data saved to config")
119
+ return
120
+
121
+ webui_config = config["webui"]
122
+
123
+ if "commit_hash" not in webui_config:
124
+ print("Error: No commit saved to webui config")
125
+ return
126
+
127
+ webui_commit_hash = webui_config.get("commit_hash", None)
128
+ webui_repo = None
129
+
130
+ try:
131
+ if os.path.exists(os.path.join(script_path, ".git")):
132
+ webui_repo = git.Repo(script_path)
133
+ except Exception:
134
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
135
+ return
136
+
137
+ try:
138
+ webui_repo.git.fetch(all=True)
139
+ webui_repo.git.reset(webui_commit_hash, hard=True)
140
+ print(f"* Restored webui to commit {webui_commit_hash}.")
141
+ except Exception:
142
+ errors.report(f"Error restoring webui to commit{webui_commit_hash}")
143
+
144
+
145
+ def restore_extension_config(config):
146
+ print("* Restoring extension state...")
147
+
148
+ if "extensions" not in config:
149
+ print("Error: No extension data saved to config")
150
+ return
151
+
152
+ ext_config = config["extensions"]
153
+
154
+ results = []
155
+ disabled = []
156
+
157
+ for ext in tqdm.tqdm(extensions.extensions):
158
+ if ext.is_builtin:
159
+ continue
160
+
161
+ ext.read_info_from_repo()
162
+ current_commit = ext.commit_hash
163
+
164
+ if ext.name not in ext_config:
165
+ ext.disabled = True
166
+ disabled.append(ext.name)
167
+ results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
168
+ continue
169
+
170
+ entry = ext_config[ext.name]
171
+
172
+ if "commit_hash" in entry and entry["commit_hash"]:
173
+ try:
174
+ ext.fetch_and_reset_hard(entry["commit_hash"])
175
+ ext.read_info_from_repo()
176
+ if current_commit != entry["commit_hash"]:
177
+ results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
178
+ except Exception as ex:
179
+ results.append((ext, current_commit[:8], False, ex))
180
+ else:
181
+ results.append((ext, current_commit[:8], False, "No commit hash found in config"))
182
+
183
+ if not entry.get("enabled", False):
184
+ ext.disabled = True
185
+ disabled.append(ext.name)
186
+ else:
187
+ ext.disabled = False
188
+
189
+ shared.opts.disabled_extensions = disabled
190
+ shared.opts.save(shared.config_filename)
191
+
192
+ print("* Finished restoring extensions. Results:")
193
+ for ext, prev_commit, success, result in results:
194
+ if success:
195
+ print(f" + {ext.name}: {prev_commit} -> {result}")
196
+ else:
197
+ print(f" ! {ext.name}: FAILURE ({result})")
modules/deepbooru.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ from modules import modelloader, paths, deepbooru_model, devices, images, shared
8
+
9
+ re_special = re.compile(r'([\\()])')
10
+
11
+
12
+ class DeepDanbooru:
13
+ def __init__(self):
14
+ self.model = None
15
+
16
+ def load(self):
17
+ if self.model is not None:
18
+ return
19
+
20
+ files = modelloader.load_models(
21
+ model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
22
+ model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
23
+ ext_filter=[".pt"],
24
+ download_name='model-resnet_custom_v3.pt',
25
+ )
26
+
27
+ self.model = deepbooru_model.DeepDanbooruModel()
28
+ self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
29
+
30
+ self.model.eval()
31
+ self.model.to(devices.cpu, devices.dtype)
32
+
33
+ def start(self):
34
+ self.load()
35
+ self.model.to(devices.device)
36
+
37
+ def stop(self):
38
+ if not shared.opts.interrogate_keep_models_in_memory:
39
+ self.model.to(devices.cpu)
40
+ devices.torch_gc()
41
+
42
+ def tag(self, pil_image):
43
+ self.start()
44
+ res = self.tag_multi(pil_image)
45
+ self.stop()
46
+
47
+ return res
48
+
49
+ def tag_multi(self, pil_image, force_disable_ranks=False):
50
+ threshold = shared.opts.interrogate_deepbooru_score_threshold
51
+ use_spaces = shared.opts.deepbooru_use_spaces
52
+ use_escape = shared.opts.deepbooru_escape
53
+ alpha_sort = shared.opts.deepbooru_sort_alpha
54
+ include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
55
+
56
+ pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
57
+ a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
58
+
59
+ with torch.no_grad(), devices.autocast():
60
+ x = torch.from_numpy(a).to(devices.device)
61
+ y = self.model(x)[0].detach().cpu().numpy()
62
+
63
+ probability_dict = {}
64
+
65
+ for tag, probability in zip(self.model.tags, y):
66
+ if probability < threshold:
67
+ continue
68
+
69
+ if tag.startswith("rating:"):
70
+ continue
71
+
72
+ probability_dict[tag] = probability
73
+
74
+ if alpha_sort:
75
+ tags = sorted(probability_dict)
76
+ else:
77
+ tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
78
+
79
+ res = []
80
+
81
+ filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
82
+
83
+ for tag in [x for x in tags if x not in filtertags]:
84
+ probability = probability_dict[tag]
85
+ tag_outformat = tag
86
+ if use_spaces:
87
+ tag_outformat = tag_outformat.replace('_', ' ')
88
+ if use_escape:
89
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
90
+ if include_ranks:
91
+ tag_outformat = f"({tag_outformat}:{probability:.3f})"
92
+
93
+ res.append(tag_outformat)
94
+
95
+ return ", ".join(res)
96
+
97
+
98
+ model = DeepDanbooru()
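A quick, hedged sketch of using the module-level model defined above; it assumes an initialized webui process, since tag_multi() reads its threshold and formatting options from shared.opts, and the image path is a placeholder.

    # Hedged sketch: tagging one image with the shared DeepDanbooru instance.
    from PIL import Image
    from modules import deepbooru

    pil_image = Image.open("sample.png")
    tags = deepbooru.model.tag(pil_image)  # loads the model on first use, moves it back per options afterwards
    print(tags)                            # e.g. "1girl, solo, smile, ..."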
modules/deepbooru_model.py ADDED
@@ -0,0 +1,678 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from modules import devices
6
+
7
+ # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
8
+
9
+
10
+ class DeepDanbooruModel(nn.Module):
11
+ def __init__(self):
12
+ super(DeepDanbooruModel, self).__init__()
13
+
14
+ self.tags = []
15
+
16
+ self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
17
+ self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
18
+ self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
19
+ self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
20
+ self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
21
+ self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
22
+ self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
23
+ self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
24
+ self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
25
+ self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
26
+ self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
27
+ self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
28
+ self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
29
+ self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
30
+ self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
31
+ self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
32
+ self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
33
+ self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
34
+ self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
35
+ self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
36
+ self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
37
+ self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
38
+ self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
39
+ self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
40
+ self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
41
+ self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
42
+ self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
43
+ self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
44
+ self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
45
+ self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
46
+ self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
47
+ self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
48
+ self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
49
+ self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
50
+ self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
51
+ self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
52
+ self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
53
+ self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
54
+ self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
55
+ self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
56
+ self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
57
+ self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
58
+ self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
59
+ self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
60
+ self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
61
+ self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
62
+ self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
63
+ self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
64
+ self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
65
+ self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
66
+ self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
67
+ self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
68
+ self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
69
+ self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
70
+ self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
71
+ self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
72
+ self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
73
+ self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
74
+ self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
75
+ self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
76
+ self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
77
+ self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
78
+ self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
79
+ self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
80
+ self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
81
+ self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
82
+ self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
83
+ self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
84
+ self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
85
+ self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
86
+ self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
87
+ self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
88
+ self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
89
+ self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
90
+ self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
91
+ self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
92
+ self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
93
+ self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
94
+ self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
95
+ self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
96
+ self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
97
+ self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
98
+ self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
99
+ self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
100
+ self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
101
+ self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
102
+ self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
103
+ self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
104
+ self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
105
+ self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
106
+ self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
107
+ self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
108
+ self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
109
+ self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
110
+ self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
111
+ self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
112
+ self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
113
+ self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
114
+ self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
115
+ self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
116
+ self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
117
+ self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
118
+ self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
119
+ self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
120
+ self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
121
+ self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
122
+ self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
123
+ self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
124
+ self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
125
+ self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
126
+ self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
127
+ self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
128
+ self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
129
+ self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
130
+ self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
131
+ self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
132
+ self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
133
+ self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
134
+ self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
135
+ self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
136
+ self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
137
+ self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
138
+ self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
139
+ self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
140
+ self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
141
+ self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
142
+ self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
143
+ self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
144
+ self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
145
+ self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
146
+ self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
147
+ self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
148
+ self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
149
+ self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
150
+ self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
151
+ self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
152
+ self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
153
+ self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
154
+ self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
155
+ self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
156
+ self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
157
+ self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
158
+ self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
159
+ self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
160
+ self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
161
+ self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
162
+ self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
163
+ self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
164
+ self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
165
+ self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
166
+ self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
167
+ self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
168
+ self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
169
+ self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
170
+ self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
171
+ self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
172
+ self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
173
+ self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
174
+ self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
175
+ self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
176
+ self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
177
+ self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
178
+ self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
179
+ self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
180
+ self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
181
+ self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
182
+ self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
183
+ self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
184
+ self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
185
+ self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
186
+ self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
187
+ self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
188
+ self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
189
+ self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
190
+ self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
191
+ self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
192
+ self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
193
+ self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
194
+ self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
195
+ self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
196
+
197
+ def forward(self, *inputs):
198
+ t_358, = inputs
199
+ t_359 = t_358.permute(*[0, 3, 1, 2])
200
+ t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
201
+ t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
202
+ t_361 = F.relu(t_360)
203
+ t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
204
+ t_362 = self.n_MaxPool_0(t_361)
205
+ t_363 = self.n_Conv_1(t_362)
206
+ t_364 = self.n_Conv_2(t_362)
207
+ t_365 = F.relu(t_364)
208
+ t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
209
+ t_366 = self.n_Conv_3(t_365_padded)
210
+ t_367 = F.relu(t_366)
211
+ t_368 = self.n_Conv_4(t_367)
212
+ t_369 = torch.add(t_368, t_363)
213
+ t_370 = F.relu(t_369)
214
+ t_371 = self.n_Conv_5(t_370)
215
+ t_372 = F.relu(t_371)
216
+ t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
217
+ t_373 = self.n_Conv_6(t_372_padded)
218
+ t_374 = F.relu(t_373)
219
+ t_375 = self.n_Conv_7(t_374)
220
+ t_376 = torch.add(t_375, t_370)
221
+ t_377 = F.relu(t_376)
222
+ t_378 = self.n_Conv_8(t_377)
223
+ t_379 = F.relu(t_378)
224
+ t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
225
+ t_380 = self.n_Conv_9(t_379_padded)
226
+ t_381 = F.relu(t_380)
227
+ t_382 = self.n_Conv_10(t_381)
228
+ t_383 = torch.add(t_382, t_377)
229
+ t_384 = F.relu(t_383)
230
+ t_385 = self.n_Conv_11(t_384)
231
+ t_386 = self.n_Conv_12(t_384)
232
+ t_387 = F.relu(t_386)
233
+ t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
234
+ t_388 = self.n_Conv_13(t_387_padded)
235
+ t_389 = F.relu(t_388)
236
+ t_390 = self.n_Conv_14(t_389)
237
+ t_391 = torch.add(t_390, t_385)
238
+ t_392 = F.relu(t_391)
239
+ t_393 = self.n_Conv_15(t_392)
240
+ t_394 = F.relu(t_393)
241
+ t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
242
+ t_395 = self.n_Conv_16(t_394_padded)
243
+ t_396 = F.relu(t_395)
244
+ t_397 = self.n_Conv_17(t_396)
245
+ t_398 = torch.add(t_397, t_392)
246
+ t_399 = F.relu(t_398)
247
+ t_400 = self.n_Conv_18(t_399)
248
+ t_401 = F.relu(t_400)
249
+ t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
250
+ t_402 = self.n_Conv_19(t_401_padded)
251
+ t_403 = F.relu(t_402)
252
+ t_404 = self.n_Conv_20(t_403)
253
+ t_405 = torch.add(t_404, t_399)
254
+ t_406 = F.relu(t_405)
255
+ t_407 = self.n_Conv_21(t_406)
256
+ t_408 = F.relu(t_407)
257
+ t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
258
+ t_409 = self.n_Conv_22(t_408_padded)
259
+ t_410 = F.relu(t_409)
260
+ t_411 = self.n_Conv_23(t_410)
261
+ t_412 = torch.add(t_411, t_406)
262
+ t_413 = F.relu(t_412)
263
+ t_414 = self.n_Conv_24(t_413)
264
+ t_415 = F.relu(t_414)
265
+ t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
266
+ t_416 = self.n_Conv_25(t_415_padded)
267
+ t_417 = F.relu(t_416)
268
+ t_418 = self.n_Conv_26(t_417)
269
+ t_419 = torch.add(t_418, t_413)
270
+ t_420 = F.relu(t_419)
271
+ t_421 = self.n_Conv_27(t_420)
272
+ t_422 = F.relu(t_421)
273
+ t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
274
+ t_423 = self.n_Conv_28(t_422_padded)
275
+ t_424 = F.relu(t_423)
276
+ t_425 = self.n_Conv_29(t_424)
277
+ t_426 = torch.add(t_425, t_420)
278
+ t_427 = F.relu(t_426)
279
+ t_428 = self.n_Conv_30(t_427)
280
+ t_429 = F.relu(t_428)
281
+ t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
282
+ t_430 = self.n_Conv_31(t_429_padded)
283
+ t_431 = F.relu(t_430)
284
+ t_432 = self.n_Conv_32(t_431)
285
+ t_433 = torch.add(t_432, t_427)
286
+ t_434 = F.relu(t_433)
287
+ t_435 = self.n_Conv_33(t_434)
288
+ t_436 = F.relu(t_435)
289
+ t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
290
+ t_437 = self.n_Conv_34(t_436_padded)
291
+ t_438 = F.relu(t_437)
292
+ t_439 = self.n_Conv_35(t_438)
293
+ t_440 = torch.add(t_439, t_434)
294
+ t_441 = F.relu(t_440)
295
+ t_442 = self.n_Conv_36(t_441)
296
+ t_443 = self.n_Conv_37(t_441)
297
+ t_444 = F.relu(t_443)
298
+ t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
299
+ t_445 = self.n_Conv_38(t_444_padded)
300
+ t_446 = F.relu(t_445)
301
+ t_447 = self.n_Conv_39(t_446)
302
+ t_448 = torch.add(t_447, t_442)
303
+ t_449 = F.relu(t_448)
304
+ t_450 = self.n_Conv_40(t_449)
305
+ t_451 = F.relu(t_450)
306
+ t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
307
+ t_452 = self.n_Conv_41(t_451_padded)
308
+ t_453 = F.relu(t_452)
309
+ t_454 = self.n_Conv_42(t_453)
310
+ t_455 = torch.add(t_454, t_449)
311
+ t_456 = F.relu(t_455)
312
+ t_457 = self.n_Conv_43(t_456)
313
+ t_458 = F.relu(t_457)
314
+ t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
315
+ t_459 = self.n_Conv_44(t_458_padded)
316
+ t_460 = F.relu(t_459)
317
+ t_461 = self.n_Conv_45(t_460)
318
+ t_462 = torch.add(t_461, t_456)
319
+ t_463 = F.relu(t_462)
320
+ t_464 = self.n_Conv_46(t_463)
321
+ t_465 = F.relu(t_464)
322
+ t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
323
+ t_466 = self.n_Conv_47(t_465_padded)
324
+ t_467 = F.relu(t_466)
325
+ t_468 = self.n_Conv_48(t_467)
326
+ t_469 = torch.add(t_468, t_463)
327
+ t_470 = F.relu(t_469)
328
+ t_471 = self.n_Conv_49(t_470)
329
+ t_472 = F.relu(t_471)
330
+ t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
331
+ t_473 = self.n_Conv_50(t_472_padded)
332
+ t_474 = F.relu(t_473)
333
+ t_475 = self.n_Conv_51(t_474)
334
+ t_476 = torch.add(t_475, t_470)
335
+ t_477 = F.relu(t_476)
336
+ t_478 = self.n_Conv_52(t_477)
337
+ t_479 = F.relu(t_478)
338
+ t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
339
+ t_480 = self.n_Conv_53(t_479_padded)
340
+ t_481 = F.relu(t_480)
341
+ t_482 = self.n_Conv_54(t_481)
342
+ t_483 = torch.add(t_482, t_477)
343
+ t_484 = F.relu(t_483)
344
+ t_485 = self.n_Conv_55(t_484)
345
+ t_486 = F.relu(t_485)
346
+ t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
347
+ t_487 = self.n_Conv_56(t_486_padded)
348
+ t_488 = F.relu(t_487)
349
+ t_489 = self.n_Conv_57(t_488)
350
+ t_490 = torch.add(t_489, t_484)
351
+ t_491 = F.relu(t_490)
352
+ t_492 = self.n_Conv_58(t_491)
353
+ t_493 = F.relu(t_492)
354
+ t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
355
+ t_494 = self.n_Conv_59(t_493_padded)
356
+ t_495 = F.relu(t_494)
357
+ t_496 = self.n_Conv_60(t_495)
358
+ t_497 = torch.add(t_496, t_491)
359
+ t_498 = F.relu(t_497)
360
+ t_499 = self.n_Conv_61(t_498)
361
+ t_500 = F.relu(t_499)
362
+ t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
363
+ t_501 = self.n_Conv_62(t_500_padded)
364
+ t_502 = F.relu(t_501)
365
+ t_503 = self.n_Conv_63(t_502)
366
+ t_504 = torch.add(t_503, t_498)
367
+ t_505 = F.relu(t_504)
368
+ t_506 = self.n_Conv_64(t_505)
369
+ t_507 = F.relu(t_506)
370
+ t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
371
+ t_508 = self.n_Conv_65(t_507_padded)
372
+ t_509 = F.relu(t_508)
373
+ t_510 = self.n_Conv_66(t_509)
374
+ t_511 = torch.add(t_510, t_505)
375
+ t_512 = F.relu(t_511)
376
+ t_513 = self.n_Conv_67(t_512)
377
+ t_514 = F.relu(t_513)
378
+ t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
379
+ t_515 = self.n_Conv_68(t_514_padded)
380
+ t_516 = F.relu(t_515)
381
+ t_517 = self.n_Conv_69(t_516)
382
+ t_518 = torch.add(t_517, t_512)
383
+ t_519 = F.relu(t_518)
384
+ t_520 = self.n_Conv_70(t_519)
385
+ t_521 = F.relu(t_520)
386
+ t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
387
+ t_522 = self.n_Conv_71(t_521_padded)
388
+ t_523 = F.relu(t_522)
389
+ t_524 = self.n_Conv_72(t_523)
390
+ t_525 = torch.add(t_524, t_519)
391
+ t_526 = F.relu(t_525)
392
+ t_527 = self.n_Conv_73(t_526)
393
+ t_528 = F.relu(t_527)
394
+ t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
395
+ t_529 = self.n_Conv_74(t_528_padded)
396
+ t_530 = F.relu(t_529)
397
+ t_531 = self.n_Conv_75(t_530)
398
+ t_532 = torch.add(t_531, t_526)
399
+ t_533 = F.relu(t_532)
400
+ t_534 = self.n_Conv_76(t_533)
401
+ t_535 = F.relu(t_534)
402
+ t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
403
+ t_536 = self.n_Conv_77(t_535_padded)
404
+ t_537 = F.relu(t_536)
405
+ t_538 = self.n_Conv_78(t_537)
406
+ t_539 = torch.add(t_538, t_533)
407
+ t_540 = F.relu(t_539)
408
+ t_541 = self.n_Conv_79(t_540)
409
+ t_542 = F.relu(t_541)
410
+ t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
411
+ t_543 = self.n_Conv_80(t_542_padded)
412
+ t_544 = F.relu(t_543)
413
+ t_545 = self.n_Conv_81(t_544)
414
+ t_546 = torch.add(t_545, t_540)
415
+ t_547 = F.relu(t_546)
416
+ t_548 = self.n_Conv_82(t_547)
417
+ t_549 = F.relu(t_548)
418
+ t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
419
+ t_550 = self.n_Conv_83(t_549_padded)
420
+ t_551 = F.relu(t_550)
421
+ t_552 = self.n_Conv_84(t_551)
422
+ t_553 = torch.add(t_552, t_547)
423
+ t_554 = F.relu(t_553)
424
+ t_555 = self.n_Conv_85(t_554)
425
+ t_556 = F.relu(t_555)
426
+ t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
427
+ t_557 = self.n_Conv_86(t_556_padded)
428
+ t_558 = F.relu(t_557)
429
+ t_559 = self.n_Conv_87(t_558)
430
+ t_560 = torch.add(t_559, t_554)
431
+ t_561 = F.relu(t_560)
432
+ t_562 = self.n_Conv_88(t_561)
433
+ t_563 = F.relu(t_562)
434
+ t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
435
+ t_564 = self.n_Conv_89(t_563_padded)
436
+ t_565 = F.relu(t_564)
437
+ t_566 = self.n_Conv_90(t_565)
438
+ t_567 = torch.add(t_566, t_561)
439
+ t_568 = F.relu(t_567)
440
+ t_569 = self.n_Conv_91(t_568)
441
+ t_570 = F.relu(t_569)
442
+ t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
443
+ t_571 = self.n_Conv_92(t_570_padded)
444
+ t_572 = F.relu(t_571)
445
+ t_573 = self.n_Conv_93(t_572)
446
+ t_574 = torch.add(t_573, t_568)
447
+ t_575 = F.relu(t_574)
448
+ t_576 = self.n_Conv_94(t_575)
449
+ t_577 = F.relu(t_576)
450
+ t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
451
+ t_578 = self.n_Conv_95(t_577_padded)
452
+ t_579 = F.relu(t_578)
453
+ t_580 = self.n_Conv_96(t_579)
454
+ t_581 = torch.add(t_580, t_575)
455
+ t_582 = F.relu(t_581)
456
+ t_583 = self.n_Conv_97(t_582)
457
+ t_584 = F.relu(t_583)
458
+ t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
459
+ t_585 = self.n_Conv_98(t_584_padded)
460
+ t_586 = F.relu(t_585)
461
+ t_587 = self.n_Conv_99(t_586)
462
+ t_588 = self.n_Conv_100(t_582)
463
+ t_589 = torch.add(t_587, t_588)
464
+ t_590 = F.relu(t_589)
465
+ t_591 = self.n_Conv_101(t_590)
466
+ t_592 = F.relu(t_591)
467
+ t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
468
+ t_593 = self.n_Conv_102(t_592_padded)
469
+ t_594 = F.relu(t_593)
470
+ t_595 = self.n_Conv_103(t_594)
471
+ t_596 = torch.add(t_595, t_590)
472
+ t_597 = F.relu(t_596)
473
+ t_598 = self.n_Conv_104(t_597)
474
+ t_599 = F.relu(t_598)
475
+ t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
476
+ t_600 = self.n_Conv_105(t_599_padded)
477
+ t_601 = F.relu(t_600)
478
+ t_602 = self.n_Conv_106(t_601)
479
+ t_603 = torch.add(t_602, t_597)
480
+ t_604 = F.relu(t_603)
481
+ t_605 = self.n_Conv_107(t_604)
482
+ t_606 = F.relu(t_605)
483
+ t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
484
+ t_607 = self.n_Conv_108(t_606_padded)
485
+ t_608 = F.relu(t_607)
486
+ t_609 = self.n_Conv_109(t_608)
487
+ t_610 = torch.add(t_609, t_604)
488
+ t_611 = F.relu(t_610)
489
+ t_612 = self.n_Conv_110(t_611)
490
+ t_613 = F.relu(t_612)
491
+ t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
492
+ t_614 = self.n_Conv_111(t_613_padded)
493
+ t_615 = F.relu(t_614)
494
+ t_616 = self.n_Conv_112(t_615)
495
+ t_617 = torch.add(t_616, t_611)
496
+ t_618 = F.relu(t_617)
497
+ t_619 = self.n_Conv_113(t_618)
498
+ t_620 = F.relu(t_619)
499
+ t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
500
+ t_621 = self.n_Conv_114(t_620_padded)
501
+ t_622 = F.relu(t_621)
502
+ t_623 = self.n_Conv_115(t_622)
503
+ t_624 = torch.add(t_623, t_618)
504
+ t_625 = F.relu(t_624)
505
+ t_626 = self.n_Conv_116(t_625)
506
+ t_627 = F.relu(t_626)
507
+ t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
508
+ t_628 = self.n_Conv_117(t_627_padded)
509
+ t_629 = F.relu(t_628)
510
+ t_630 = self.n_Conv_118(t_629)
511
+ t_631 = torch.add(t_630, t_625)
512
+ t_632 = F.relu(t_631)
513
+ t_633 = self.n_Conv_119(t_632)
514
+ t_634 = F.relu(t_633)
515
+ t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
516
+ t_635 = self.n_Conv_120(t_634_padded)
517
+ t_636 = F.relu(t_635)
518
+ t_637 = self.n_Conv_121(t_636)
519
+ t_638 = torch.add(t_637, t_632)
520
+ t_639 = F.relu(t_638)
521
+ t_640 = self.n_Conv_122(t_639)
522
+ t_641 = F.relu(t_640)
523
+ t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
524
+ t_642 = self.n_Conv_123(t_641_padded)
525
+ t_643 = F.relu(t_642)
526
+ t_644 = self.n_Conv_124(t_643)
527
+ t_645 = torch.add(t_644, t_639)
528
+ t_646 = F.relu(t_645)
529
+ t_647 = self.n_Conv_125(t_646)
530
+ t_648 = F.relu(t_647)
531
+ t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
532
+ t_649 = self.n_Conv_126(t_648_padded)
533
+ t_650 = F.relu(t_649)
534
+ t_651 = self.n_Conv_127(t_650)
535
+ t_652 = torch.add(t_651, t_646)
536
+ t_653 = F.relu(t_652)
537
+ t_654 = self.n_Conv_128(t_653)
538
+ t_655 = F.relu(t_654)
539
+ t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
540
+ t_656 = self.n_Conv_129(t_655_padded)
541
+ t_657 = F.relu(t_656)
542
+ t_658 = self.n_Conv_130(t_657)
543
+ t_659 = torch.add(t_658, t_653)
544
+ t_660 = F.relu(t_659)
545
+ t_661 = self.n_Conv_131(t_660)
546
+ t_662 = F.relu(t_661)
547
+ t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
548
+ t_663 = self.n_Conv_132(t_662_padded)
549
+ t_664 = F.relu(t_663)
550
+ t_665 = self.n_Conv_133(t_664)
551
+ t_666 = torch.add(t_665, t_660)
552
+ t_667 = F.relu(t_666)
553
+ t_668 = self.n_Conv_134(t_667)
554
+ t_669 = F.relu(t_668)
555
+ t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
556
+ t_670 = self.n_Conv_135(t_669_padded)
557
+ t_671 = F.relu(t_670)
558
+ t_672 = self.n_Conv_136(t_671)
559
+ t_673 = torch.add(t_672, t_667)
560
+ t_674 = F.relu(t_673)
561
+ t_675 = self.n_Conv_137(t_674)
562
+ t_676 = F.relu(t_675)
563
+ t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
564
+ t_677 = self.n_Conv_138(t_676_padded)
565
+ t_678 = F.relu(t_677)
566
+ t_679 = self.n_Conv_139(t_678)
567
+ t_680 = torch.add(t_679, t_674)
568
+ t_681 = F.relu(t_680)
569
+ t_682 = self.n_Conv_140(t_681)
570
+ t_683 = F.relu(t_682)
571
+ t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
572
+ t_684 = self.n_Conv_141(t_683_padded)
573
+ t_685 = F.relu(t_684)
574
+ t_686 = self.n_Conv_142(t_685)
575
+ t_687 = torch.add(t_686, t_681)
576
+ t_688 = F.relu(t_687)
577
+ t_689 = self.n_Conv_143(t_688)
578
+ t_690 = F.relu(t_689)
579
+ t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
580
+ t_691 = self.n_Conv_144(t_690_padded)
581
+ t_692 = F.relu(t_691)
582
+ t_693 = self.n_Conv_145(t_692)
583
+ t_694 = torch.add(t_693, t_688)
584
+ t_695 = F.relu(t_694)
585
+ t_696 = self.n_Conv_146(t_695)
586
+ t_697 = F.relu(t_696)
587
+ t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
588
+ t_698 = self.n_Conv_147(t_697_padded)
589
+ t_699 = F.relu(t_698)
590
+ t_700 = self.n_Conv_148(t_699)
591
+ t_701 = torch.add(t_700, t_695)
592
+ t_702 = F.relu(t_701)
593
+ t_703 = self.n_Conv_149(t_702)
594
+ t_704 = F.relu(t_703)
595
+ t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
596
+ t_705 = self.n_Conv_150(t_704_padded)
597
+ t_706 = F.relu(t_705)
598
+ t_707 = self.n_Conv_151(t_706)
599
+ t_708 = torch.add(t_707, t_702)
600
+ t_709 = F.relu(t_708)
601
+ t_710 = self.n_Conv_152(t_709)
602
+ t_711 = F.relu(t_710)
603
+ t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
604
+ t_712 = self.n_Conv_153(t_711_padded)
605
+ t_713 = F.relu(t_712)
606
+ t_714 = self.n_Conv_154(t_713)
607
+ t_715 = torch.add(t_714, t_709)
608
+ t_716 = F.relu(t_715)
609
+ t_717 = self.n_Conv_155(t_716)
610
+ t_718 = F.relu(t_717)
611
+ t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
612
+ t_719 = self.n_Conv_156(t_718_padded)
613
+ t_720 = F.relu(t_719)
614
+ t_721 = self.n_Conv_157(t_720)
615
+ t_722 = torch.add(t_721, t_716)
616
+ t_723 = F.relu(t_722)
617
+ t_724 = self.n_Conv_158(t_723)
618
+ t_725 = self.n_Conv_159(t_723)
619
+ t_726 = F.relu(t_725)
620
+ t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
621
+ t_727 = self.n_Conv_160(t_726_padded)
622
+ t_728 = F.relu(t_727)
623
+ t_729 = self.n_Conv_161(t_728)
624
+ t_730 = torch.add(t_729, t_724)
625
+ t_731 = F.relu(t_730)
626
+ t_732 = self.n_Conv_162(t_731)
627
+ t_733 = F.relu(t_732)
628
+ t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
629
+ t_734 = self.n_Conv_163(t_733_padded)
630
+ t_735 = F.relu(t_734)
631
+ t_736 = self.n_Conv_164(t_735)
632
+ t_737 = torch.add(t_736, t_731)
633
+ t_738 = F.relu(t_737)
634
+ t_739 = self.n_Conv_165(t_738)
635
+ t_740 = F.relu(t_739)
636
+ t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
637
+ t_741 = self.n_Conv_166(t_740_padded)
638
+ t_742 = F.relu(t_741)
639
+ t_743 = self.n_Conv_167(t_742)
640
+ t_744 = torch.add(t_743, t_738)
641
+ t_745 = F.relu(t_744)
642
+ t_746 = self.n_Conv_168(t_745)
643
+ t_747 = self.n_Conv_169(t_745)
644
+ t_748 = F.relu(t_747)
645
+ t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
646
+ t_749 = self.n_Conv_170(t_748_padded)
647
+ t_750 = F.relu(t_749)
648
+ t_751 = self.n_Conv_171(t_750)
649
+ t_752 = torch.add(t_751, t_746)
650
+ t_753 = F.relu(t_752)
651
+ t_754 = self.n_Conv_172(t_753)
652
+ t_755 = F.relu(t_754)
653
+ t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
654
+ t_756 = self.n_Conv_173(t_755_padded)
655
+ t_757 = F.relu(t_756)
656
+ t_758 = self.n_Conv_174(t_757)
657
+ t_759 = torch.add(t_758, t_753)
658
+ t_760 = F.relu(t_759)
659
+ t_761 = self.n_Conv_175(t_760)
660
+ t_762 = F.relu(t_761)
661
+ t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
662
+ t_763 = self.n_Conv_176(t_762_padded)
663
+ t_764 = F.relu(t_763)
664
+ t_765 = self.n_Conv_177(t_764)
665
+ t_766 = torch.add(t_765, t_760)
666
+ t_767 = F.relu(t_766)
667
+ t_768 = self.n_Conv_178(t_767)
668
+ t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
669
+ t_770 = torch.squeeze(t_769, 3)
670
+ t_770 = torch.squeeze(t_770, 2)
671
+ t_771 = torch.sigmoid(t_770)
672
+ return t_771
673
+
674
+ def load_state_dict(self, state_dict, **kwargs):
675
+ self.tags = state_dict.get('tags', [])
676
+
677
+ super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
678
+
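The forward pass above takes an NHWC float tensor in [0, 1] (note the permute to NCHW on entry) and returns one sigmoid probability per entry of self.tags. A hedged sketch of driving the model directly, mirroring what deepbooru.py does; the checkpoint path, image path, and 0.5 threshold are placeholders.

    # Hedged sketch: one 512x512 RGB image through DeepDanbooruModel on CPU.
    import numpy as np
    import torch
    from PIL import Image

    from modules.deepbooru_model import DeepDanbooruModel

    model = DeepDanbooruModel()
    model.load_state_dict(torch.load("model-resnet_custom_v3.pt", map_location="cpu"))  # placeholder path
    model.eval()

    img = Image.open("sample.png").convert("RGB").resize((512, 512))
    x = torch.from_numpy(np.expand_dims(np.array(img, dtype=np.float32), 0) / 255)  # shape (1, 512, 512, 3)

    with torch.no_grad():
        probs = model(x)[0]  # aligned with model.tags

    for tag, p in zip(model.tags, probs):
        if p > 0.5:
            print(tag, float(p))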
modules/devices.py ADDED
@@ -0,0 +1,171 @@
1
+ import sys
2
+ import contextlib
3
+ from functools import lru_cache
4
+
5
+ import torch
6
+ from modules import errors
7
+
8
+ if sys.platform == "darwin":
9
+ from modules import mac_specific
10
+
11
+
12
+ def has_mps() -> bool:
13
+ if sys.platform != "darwin":
14
+ return False
15
+ else:
16
+ return mac_specific.has_mps
17
+
18
+
19
+ def get_cuda_device_string():
20
+ from modules import shared
21
+
22
+ if shared.cmd_opts.device_id is not None:
23
+ return f"cuda:{shared.cmd_opts.device_id}"
24
+
25
+ return "cuda"
26
+
27
+
28
+ def get_optimal_device_name():
29
+ if torch.cuda.is_available():
30
+ return get_cuda_device_string()
31
+
32
+ if has_mps():
33
+ return "mps"
34
+
35
+ return "cpu"
36
+
37
+
38
+ def get_optimal_device():
39
+ return torch.device(get_optimal_device_name())
40
+
41
+
42
+ def get_device_for(task):
43
+ from modules import shared
44
+
45
+ if task in shared.cmd_opts.use_cpu:
46
+ return cpu
47
+
48
+ return get_optimal_device()
49
+
50
+
51
+ def torch_gc():
52
+
53
+ if torch.cuda.is_available():
54
+ with torch.cuda.device(get_cuda_device_string()):
55
+ torch.cuda.empty_cache()
56
+ torch.cuda.ipc_collect()
57
+
58
+ if has_mps():
59
+ mac_specific.torch_mps_gc()
60
+
61
+
62
+ def enable_tf32():
63
+ if torch.cuda.is_available():
64
+
65
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
66
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
67
+ if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
68
+ torch.backends.cudnn.benchmark = True
69
+
70
+ torch.backends.cuda.matmul.allow_tf32 = True
71
+ torch.backends.cudnn.allow_tf32 = True
72
+
73
+
74
+
75
+ errors.run(enable_tf32, "Enabling TF32")
76
+
77
+ cpu = torch.device("cpu")
78
+ device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
79
+ dtype = torch.float16
80
+ dtype_vae = torch.float16
81
+ dtype_unet = torch.float16
82
+ unet_needs_upcast = False
83
+
84
+
85
+ def cond_cast_unet(input):
86
+ return input.to(dtype_unet) if unet_needs_upcast else input
87
+
88
+
89
+ def cond_cast_float(input):
90
+ return input.float() if unet_needs_upcast else input
91
+
92
+
93
+ def randn(seed, shape):
94
+ from modules.shared import opts
95
+
96
+ torch.manual_seed(seed)
97
+ if opts.randn_source == "CPU" or device.type == 'mps':
98
+ return torch.randn(shape, device=cpu).to(device)
99
+ return torch.randn(shape, device=device)
100
+
101
+
102
+ def randn_without_seed(shape):
103
+ from modules.shared import opts
104
+
105
+ if opts.randn_source == "CPU" or device.type == 'mps':
106
+ return torch.randn(shape, device=cpu).to(device)
107
+ return torch.randn(shape, device=device)
108
+
109
+
110
+ def autocast(disable=False):
111
+ from modules import shared
112
+
113
+ if disable:
114
+ return contextlib.nullcontext()
115
+
116
+ if dtype == torch.float32 or shared.cmd_opts.precision == "full":
117
+ return contextlib.nullcontext()
118
+
119
+ return torch.autocast("cuda")
120
+
121
+
122
+ def without_autocast(disable=False):
123
+ return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
124
+
125
+
126
+ class NansException(Exception):
127
+ pass
128
+
129
+
130
+ def test_for_nans(x, where):
131
+ from modules import shared
132
+
133
+ if shared.cmd_opts.disable_nan_check:
134
+ return
135
+
136
+ if not torch.all(torch.isnan(x)).item():
137
+ return
138
+
139
+ if where == "unet":
140
+ message = "A tensor with all NaNs was produced in Unet."
141
+
142
+ if not shared.cmd_opts.no_half:
143
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
144
+
145
+ elif where == "vae":
146
+ message = "A tensor with all NaNs was produced in VAE."
147
+
148
+ if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
149
+ message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
150
+ else:
151
+ message = "A tensor with all NaNs was produced."
152
+
153
+ message += " Use --disable-nan-check commandline argument to disable this check."
154
+
155
+ raise NansException(message)
156
+
157
+
158
+ @lru_cache
159
+ def first_time_calculation():
160
+ """
161
+ just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
162
+ spends about 2.7 seconds doing that, at least with NVidia.
163
+ """
164
+
165
+ x = torch.zeros((1, 1)).to(device, dtype)
166
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
167
+ linear(x)
168
+
169
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
170
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
171
+ conv2d(x)
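To show how these helpers combine, here is a hedged sketch of the usual call pattern around an inference step: pick a device, run under autocast, check for NaNs, then free cached memory. It assumes an initialized webui process (autocast and test_for_nans read modules.shared), and the Linear layer stands in for a real model.

    # Hedged sketch: typical devices.py usage around one forward pass.
    import torch
    from modules import devices

    device = devices.get_optimal_device()       # cuda:N, mps, or cpu
    model = torch.nn.Linear(4, 4).to(device)    # placeholder for a real model

    x = torch.randn((1, 4), device=device)
    with devices.autocast():                    # no-op context when running in full precision
        y = model(x)

    devices.test_for_nans(y, "unet")            # raises NansException only if the output is all NaNs
    devices.torch_gc()                          # release cached CUDA/MPS memory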
modules/errors.py ADDED
@@ -0,0 +1,85 @@
1
+ import sys
2
+ import textwrap
3
+ import traceback
4
+
5
+
6
+ exception_records = []
7
+
8
+
9
+ def record_exception():
10
+ _, e, tb = sys.exc_info()
11
+ if e is None:
12
+ return
13
+
14
+ if exception_records and exception_records[-1] == e:
15
+ return
16
+
17
+ exception_records.append((e, tb))
18
+
19
+ if len(exception_records) > 5:
20
+ exception_records.pop(0)
21
+
22
+
23
+ def report(message: str, *, exc_info: bool = False) -> None:
24
+ """
25
+ Print an error message to stderr, with optional traceback.
26
+ """
27
+
28
+ record_exception()
29
+
30
+ for line in message.splitlines():
31
+ print("***", line, file=sys.stderr)
32
+ if exc_info:
33
+ print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
34
+ print("---", file=sys.stderr)
35
+
36
+
37
+ def print_error_explanation(message):
38
+ record_exception()
39
+
40
+ lines = message.strip().split("\n")
41
+ max_len = max([len(x) for x in lines])
42
+
43
+ print('=' * max_len, file=sys.stderr)
44
+ for line in lines:
45
+ print(line, file=sys.stderr)
46
+ print('=' * max_len, file=sys.stderr)
47
+
48
+
49
+ def display(e: Exception, task, *, full_traceback=False):
50
+ record_exception()
51
+
52
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
53
+ te = traceback.TracebackException.from_exception(e)
54
+ if full_traceback:
55
+ # include frames leading up to the try-catch block
56
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
57
+ print(*te.format(), sep="", file=sys.stderr)
58
+
59
+ message = str(e)
60
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
61
+ print_error_explanation("""
62
+ The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
63
+ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
64
+ """)
65
+
66
+
67
+ already_displayed = {}
68
+
69
+
70
+ def display_once(e: Exception, task):
71
+ record_exception()
72
+
73
+ if task in already_displayed:
74
+ return
75
+
76
+ display(e, task)
77
+
78
+ already_displayed[task] = 1
79
+
80
+
81
+ def run(code, task):
82
+ try:
83
+ code()
84
+ except Exception as e:
85
+ display(e, task)
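A short, hedged sketch of the two main entry points: run() wraps a callable so a failure is printed with its task name instead of propagating, and report() logs from inside an except block with an optional traceback. The failing function is a placeholder.

    # Hedged sketch: using errors.run and errors.report.
    from modules import errors

    def risky_setup():                          # placeholder for a real setup step
        raise RuntimeError("boom")

    errors.run(risky_setup, "Setting up the hypothetical frobnicator")

    try:
        risky_setup()
    except Exception:
        errors.report("Failed inference for the hypothetical frobnicator", exc_info=True)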
modules/esrgan_model.py ADDED
@@ -0,0 +1,229 @@
1
+ import sys
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+
7
+ import modules.esrgan_model_arch as arch
8
+ from modules import modelloader, images, devices
9
+ from modules.shared import opts
10
+ from modules.upscaler import Upscaler, UpscalerData
11
+
12
+
13
+ def mod2normal(state_dict):
14
+ # this code is copied from https://github.com/victorca25/iNNfer
15
+ if 'conv_first.weight' in state_dict:
16
+ crt_net = {}
17
+ items = list(state_dict)
18
+
19
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
20
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
21
+
22
+ for k in items.copy():
23
+ if 'RDB' in k:
24
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
25
+ if '.weight' in k:
26
+ ori_k = ori_k.replace('.weight', '.0.weight')
27
+ elif '.bias' in k:
28
+ ori_k = ori_k.replace('.bias', '.0.bias')
29
+ crt_net[ori_k] = state_dict[k]
30
+ items.remove(k)
31
+
32
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
33
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
34
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
35
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
36
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
37
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
38
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
39
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
40
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
41
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
42
+ state_dict = crt_net
43
+ return state_dict
44
+
45
+
46
+ def resrgan2normal(state_dict, nb=23):
47
+ # this code is copied from https://github.com/victorca25/iNNfer
48
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
49
+ re8x = 0
50
+ crt_net = {}
51
+ items = list(state_dict)
52
+
53
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
54
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
55
+
56
+ for k in items.copy():
57
+ if "rdb" in k:
58
+ ori_k = k.replace('body.', 'model.1.sub.')
59
+ ori_k = ori_k.replace('.rdb', '.RDB')
60
+ if '.weight' in k:
61
+ ori_k = ori_k.replace('.weight', '.0.weight')
62
+ elif '.bias' in k:
63
+ ori_k = ori_k.replace('.bias', '.0.bias')
64
+ crt_net[ori_k] = state_dict[k]
65
+ items.remove(k)
66
+
67
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
68
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
69
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
70
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
71
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
72
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
73
+
74
+ if 'conv_up3.weight' in state_dict:
75
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
76
+ re8x = 3
77
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
78
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
79
+
80
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
81
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
82
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
83
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
84
+
85
+ state_dict = crt_net
86
+ return state_dict
87
+
88
+
89
+ def infer_params(state_dict):
90
+ # this code is copied from https://github.com/victorca25/iNNfer
91
+ scale2x = 0
92
+ scalemin = 6
93
+ n_uplayer = 0
94
+ plus = False
95
+
96
+ for block in list(state_dict):
97
+ parts = block.split(".")
98
+ n_parts = len(parts)
99
+ if n_parts == 5 and parts[2] == "sub":
100
+ nb = int(parts[3])
101
+ elif n_parts == 3:
102
+ part_num = int(parts[1])
103
+ if (part_num > scalemin
104
+ and parts[0] == "model"
105
+ and parts[2] == "weight"):
106
+ scale2x += 1
107
+ if part_num > n_uplayer:
108
+ n_uplayer = part_num
109
+ out_nc = state_dict[block].shape[0]
110
+ if not plus and "conv1x1" in block:
111
+ plus = True
112
+
113
+ nf = state_dict["model.0.weight"].shape[0]
114
+ in_nc = state_dict["model.0.weight"].shape[1]
115
+ out_nc = out_nc
116
+ scale = 2 ** scale2x
117
+
118
+ return in_nc, out_nc, nf, nb, plus, scale
119
+
120
+
121
+ class UpscalerESRGAN(Upscaler):
122
+ def __init__(self, dirname):
123
+ self.name = "ESRGAN"
124
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
125
+ self.model_name = "ESRGAN_4x"
126
+ self.scalers = []
127
+ self.user_path = dirname
128
+ super().__init__()
129
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
130
+ scalers = []
131
+ if len(model_paths) == 0:
132
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
133
+ self.scalers.append(scaler_data)
134
+ for file in model_paths:
135
+ if file.startswith("http"):
136
+ name = self.model_name
137
+ else:
138
+ name = modelloader.friendly_name(file)
139
+
140
+ scaler_data = UpscalerData(name, file, self, 4)
141
+ self.scalers.append(scaler_data)
142
+
143
+ def do_upscale(self, img, selected_model):
144
+ try:
145
+ model = self.load_model(selected_model)
146
+ except Exception as e:
147
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
148
+ return img
149
+ model.to(devices.device_esrgan)
150
+ img = esrgan_upscale(model, img)
151
+ return img
152
+
153
+ def load_model(self, path: str):
154
+ if path.startswith("http"):
155
+ # TODO: this doesn't use `path` at all?
156
+ filename = modelloader.load_file_from_url(
157
+ url=self.model_url,
158
+ model_dir=self.model_download_path,
159
+ file_name=f"{self.model_name}.pth",
160
+ )
161
+ else:
162
+ filename = path
163
+
164
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
165
+
166
+ if "params_ema" in state_dict:
167
+ state_dict = state_dict["params_ema"]
168
+ elif "params" in state_dict:
169
+ state_dict = state_dict["params"]
170
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
171
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
172
+ model.load_state_dict(state_dict)
173
+ model.eval()
174
+ return model
175
+
176
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
177
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
178
+ state_dict = resrgan2normal(state_dict, nb)
179
+ elif "conv_first.weight" in state_dict:
180
+ state_dict = mod2normal(state_dict)
181
+ elif "model.0.weight" not in state_dict:
182
+ raise Exception("The file is not a recognized ESRGAN model.")
183
+
184
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
185
+
186
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
187
+ model.load_state_dict(state_dict)
188
+ model.eval()
189
+
190
+ return model
191
+
192
+
193
+ def upscale_without_tiling(model, img):
194
+ img = np.array(img)
195
+ img = img[:, :, ::-1]
196
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
197
+ img = torch.from_numpy(img).float()
198
+ img = img.unsqueeze(0).to(devices.device_esrgan)
199
+ with torch.no_grad():
200
+ output = model(img)
201
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
202
+ output = 255. * np.moveaxis(output, 0, 2)
203
+ output = output.astype(np.uint8)
204
+ output = output[:, :, ::-1]
205
+ return Image.fromarray(output, 'RGB')
206
+
207
+
208
+ def esrgan_upscale(model, img):
209
+ if opts.ESRGAN_tile == 0:
210
+ return upscale_without_tiling(model, img)
211
+
212
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
213
+ newtiles = []
214
+ scale_factor = 1
215
+
216
+ for y, h, row in grid.tiles:
217
+ newrow = []
218
+ for tiledata in row:
219
+ x, w, tile = tiledata
220
+
221
+ output = upscale_without_tiling(model, tile)
222
+ scale_factor = output.width // tile.width
223
+
224
+ newrow.append([x * scale_factor, w * scale_factor, output])
225
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
226
+
227
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
228
+ output = images.combine_grid(newgrid)
229
+ return output
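For orientation, the tiled path above splits the input into overlapping tiles (controlled by the ESRGAN_tile and ESRGAN_tile_overlap options), upscales each tile independently, and stitches the results back together. A minimal usage sketch, assuming the webui environment (shared options, devices) is already initialized; the input filename is illustrative only:

from PIL import Image
from modules.esrgan_model import UpscalerESRGAN

upscaler = UpscalerESRGAN(dirname=None)        # scans the default model directory for .pt/.pth files
img = Image.open("input.png").convert("RGB")
if upscaler.scalers:                           # may be empty if no local models were found
    result = upscaler.do_upscale(img, upscaler.scalers[0].data_path)
    result.save("input_4x.png")                # ESRGAN models here are treated as 4x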
modules/esrgan_model_arch.py ADDED
@@ -0,0 +1,465 @@
1
+ # this file is adapted from https://github.com/victorca25/iNNfer
2
+
3
+ from collections import OrderedDict
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ ####################
11
+ # RRDBNet Generator
12
+ ####################
13
+
14
+ class RRDBNet(nn.Module):
15
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
16
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
17
+ finalact=None, gaussian_noise=False, plus=False):
18
+ super(RRDBNet, self).__init__()
19
+ n_upscale = int(math.log(upscale, 2))
20
+ if upscale == 3:
21
+ n_upscale = 1
22
+
23
+ self.resrgan_scale = 0
24
+ if in_nc % 16 == 0:
25
+ self.resrgan_scale = 1
26
+ elif in_nc != 4 and in_nc % 4 == 0:
27
+ self.resrgan_scale = 2
28
+
29
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
30
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
31
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
32
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
33
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
34
+
35
+ if upsample_mode == 'upconv':
36
+ upsample_block = upconv_block
37
+ elif upsample_mode == 'pixelshuffle':
38
+ upsample_block = pixelshuffle_block
39
+ else:
40
+ raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
41
+ if upscale == 3:
42
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
43
+ else:
44
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
45
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
46
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
47
+
48
+ outact = act(finalact) if finalact else None
49
+
50
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
51
+ *upsampler, HR_conv0, HR_conv1, outact)
52
+
53
+ def forward(self, x, outm=None):
54
+ if self.resrgan_scale == 1:
55
+ feat = pixel_unshuffle(x, scale=4)
56
+ elif self.resrgan_scale == 2:
57
+ feat = pixel_unshuffle(x, scale=2)
58
+ else:
59
+ feat = x
60
+
61
+ return self.model(feat)
62
+
63
+
64
+ class RRDB(nn.Module):
65
+ """
66
+ Residual in Residual Dense Block
67
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
68
+ """
69
+
70
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
71
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
72
+ spectral_norm=False, gaussian_noise=False, plus=False):
73
+ super(RRDB, self).__init__()
74
+ # This is for backwards compatibility with existing models
75
+ if nr == 3:
76
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
77
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
78
+ gaussian_noise=gaussian_noise, plus=plus)
79
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
80
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
81
+ gaussian_noise=gaussian_noise, plus=plus)
82
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
83
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
84
+ gaussian_noise=gaussian_noise, plus=plus)
85
+ else:
86
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
87
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
88
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
89
+ self.RDBs = nn.Sequential(*RDB_list)
90
+
91
+ def forward(self, x):
92
+ if hasattr(self, 'RDB1'):
93
+ out = self.RDB1(x)
94
+ out = self.RDB2(out)
95
+ out = self.RDB3(out)
96
+ else:
97
+ out = self.RDBs(x)
98
+ return out * 0.2 + x
99
+
100
+
101
+ class ResidualDenseBlock_5C(nn.Module):
102
+ """
103
+ Residual Dense Block
104
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
105
+ Modified options that can be used:
106
+ - "Partial Convolution based Padding" arXiv:1811.11718
107
+ - "Spectral normalization" arXiv:1802.05957
108
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
109
+ {Rakotonirina} and A. {Rasoanaivo}
110
+ """
111
+
112
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
113
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
114
+ spectral_norm=False, gaussian_noise=False, plus=False):
115
+ super(ResidualDenseBlock_5C, self).__init__()
116
+
117
+ self.noise = GaussianNoise() if gaussian_noise else None
118
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
119
+
120
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
121
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
122
+ spectral_norm=spectral_norm)
123
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
124
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
125
+ spectral_norm=spectral_norm)
126
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
127
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
128
+ spectral_norm=spectral_norm)
129
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
130
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
131
+ spectral_norm=spectral_norm)
132
+ if mode == 'CNA':
133
+ last_act = None
134
+ else:
135
+ last_act = act_type
136
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
137
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
138
+ spectral_norm=spectral_norm)
139
+
140
+ def forward(self, x):
141
+ x1 = self.conv1(x)
142
+ x2 = self.conv2(torch.cat((x, x1), 1))
143
+ if self.conv1x1:
144
+ x2 = x2 + self.conv1x1(x)
145
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
146
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
147
+ if self.conv1x1:
148
+ x4 = x4 + x2
149
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
150
+ if self.noise:
151
+ return self.noise(x5.mul(0.2) + x)
152
+ else:
153
+ return x5 * 0.2 + x
154
+
155
+
156
+ ####################
157
+ # ESRGANplus
158
+ ####################
159
+
160
+ class GaussianNoise(nn.Module):
161
+ def __init__(self, sigma=0.1, is_relative_detach=False):
162
+ super().__init__()
163
+ self.sigma = sigma
164
+ self.is_relative_detach = is_relative_detach
165
+ self.noise = torch.tensor(0, dtype=torch.float)
166
+
167
+ def forward(self, x):
168
+ if self.training and self.sigma != 0:
169
+ self.noise = self.noise.to(x.device)
170
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
171
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
172
+ x = x + sampled_noise
173
+ return x
174
+
175
+ def conv1x1(in_planes, out_planes, stride=1):
176
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
177
+
178
+
179
+ ####################
180
+ # SRVGGNetCompact
181
+ ####################
182
+
183
+ class SRVGGNetCompact(nn.Module):
184
+ """A compact VGG-style network structure for super-resolution.
185
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
186
+ """
187
+
188
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
189
+ super(SRVGGNetCompact, self).__init__()
190
+ self.num_in_ch = num_in_ch
191
+ self.num_out_ch = num_out_ch
192
+ self.num_feat = num_feat
193
+ self.num_conv = num_conv
194
+ self.upscale = upscale
195
+ self.act_type = act_type
196
+
197
+ self.body = nn.ModuleList()
198
+ # the first conv
199
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
200
+ # the first activation
201
+ if act_type == 'relu':
202
+ activation = nn.ReLU(inplace=True)
203
+ elif act_type == 'prelu':
204
+ activation = nn.PReLU(num_parameters=num_feat)
205
+ elif act_type == 'leakyrelu':
206
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
207
+ self.body.append(activation)
208
+
209
+ # the body structure
210
+ for _ in range(num_conv):
211
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
212
+ # activation
213
+ if act_type == 'relu':
214
+ activation = nn.ReLU(inplace=True)
215
+ elif act_type == 'prelu':
216
+ activation = nn.PReLU(num_parameters=num_feat)
217
+ elif act_type == 'leakyrelu':
218
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
219
+ self.body.append(activation)
220
+
221
+ # the last conv
222
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
223
+ # upsample
224
+ self.upsampler = nn.PixelShuffle(upscale)
225
+
226
+ def forward(self, x):
227
+ out = x
228
+ for i in range(0, len(self.body)):
229
+ out = self.body[i](out)
230
+
231
+ out = self.upsampler(out)
232
+ # add the nearest upsampled image, so that the network learns the residual
233
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
234
+ out += base
235
+ return out
236
+
237
+
238
+ ####################
239
+ # Upsampler
240
+ ####################
241
+
242
+ class Upsample(nn.Module):
243
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
244
+ The input data is assumed to be of the form
245
+ `minibatch x channels x [optional depth] x [optional height] x width`.
246
+ """
247
+
248
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
249
+ super(Upsample, self).__init__()
250
+ if isinstance(scale_factor, tuple):
251
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
252
+ else:
253
+ self.scale_factor = float(scale_factor) if scale_factor else None
254
+ self.mode = mode
255
+ self.size = size
256
+ self.align_corners = align_corners
257
+
258
+ def forward(self, x):
259
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
260
+
261
+ def extra_repr(self):
262
+ if self.scale_factor is not None:
263
+ info = f'scale_factor={self.scale_factor}'
264
+ else:
265
+ info = f'size={self.size}'
266
+ info += f', mode={self.mode}'
267
+ return info
268
+
269
+
270
+ def pixel_unshuffle(x, scale):
271
+ """ Pixel unshuffle.
272
+ Args:
273
+ x (Tensor): Input feature with shape (b, c, hh, hw).
274
+ scale (int): Downsample ratio.
275
+ Returns:
276
+ Tensor: the pixel unshuffled feature.
277
+ """
278
+ b, c, hh, hw = x.size()
279
+ out_channel = c * (scale**2)
280
+ assert hh % scale == 0 and hw % scale == 0
281
+ h = hh // scale
282
+ w = hw // scale
283
+ x_view = x.view(b, c, h, scale, w, scale)
284
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
285
+
286
+
287
+ def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
288
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
289
+ """
290
+ Pixel shuffle layer
291
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
292
+ Neural Network, CVPR17)
293
+ """
294
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
295
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
296
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
297
+
298
+ n = norm(norm_type, out_nc) if norm_type else None
299
+ a = act(act_type) if act_type else None
300
+ return sequential(conv, pixel_shuffle, n, a)
301
+
302
+
303
+ def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
304
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
305
+ """ Upconv layer """
306
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
307
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
308
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
309
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
310
+ return sequential(upsample, conv)
311
+
312
+
313
+
314
+
315
+
316
+
317
+
318
+
319
+ ####################
320
+ # Basic blocks
321
+ ####################
322
+
323
+
324
+ def make_layer(basic_block, num_basic_block, **kwarg):
325
+ """Make layers by stacking the same blocks.
326
+ Args:
327
+ basic_block (nn.module): nn.module class for basic block. (block)
328
+ num_basic_block (int): number of blocks. (n_layers)
329
+ Returns:
330
+ nn.Sequential: Stacked blocks in nn.Sequential.
331
+ """
332
+ layers = []
333
+ for _ in range(num_basic_block):
334
+ layers.append(basic_block(**kwarg))
335
+ return nn.Sequential(*layers)
336
+
337
+
338
+ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
339
+ """ activation helper """
340
+ act_type = act_type.lower()
341
+ if act_type == 'relu':
342
+ layer = nn.ReLU(inplace)
343
+ elif act_type in ('leakyrelu', 'lrelu'):
344
+ layer = nn.LeakyReLU(neg_slope, inplace)
345
+ elif act_type == 'prelu':
346
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
347
+ elif act_type == 'tanh': # [-1, 1] range output
348
+ layer = nn.Tanh()
349
+ elif act_type == 'sigmoid': # [0, 1] range output
350
+ layer = nn.Sigmoid()
351
+ else:
352
+ raise NotImplementedError(f'activation layer [{act_type}] is not found')
353
+ return layer
354
+
355
+
356
+ class Identity(nn.Module):
357
+ def __init__(self, *kwargs):
358
+ super(Identity, self).__init__()
359
+
360
+ def forward(self, x, *kwargs):
361
+ return x
362
+
363
+
364
+ def norm(norm_type, nc):
365
+ """ Return a normalization layer """
366
+ norm_type = norm_type.lower()
367
+ if norm_type == 'batch':
368
+ layer = nn.BatchNorm2d(nc, affine=True)
369
+ elif norm_type == 'instance':
370
+ layer = nn.InstanceNorm2d(nc, affine=False)
371
+ elif norm_type == 'none':
372
+ layer = Identity()  # assign layer so the function's final return works for the 'none' case
373
+ else:
374
+ raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
375
+ return layer
376
+
377
+
378
+ def pad(pad_type, padding):
379
+ """ padding layer helper """
380
+ pad_type = pad_type.lower()
381
+ if padding == 0:
382
+ return None
383
+ if pad_type == 'reflect':
384
+ layer = nn.ReflectionPad2d(padding)
385
+ elif pad_type == 'replicate':
386
+ layer = nn.ReplicationPad2d(padding)
387
+ elif pad_type == 'zero':
388
+ layer = nn.ZeroPad2d(padding)
389
+ else:
390
+ raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
391
+ return layer
392
+
393
+
394
+ def get_valid_padding(kernel_size, dilation):
395
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
396
+ padding = (kernel_size - 1) // 2
397
+ return padding
398
+
399
+
400
+ class ShortcutBlock(nn.Module):
401
+ """ Elementwise sum the output of a submodule to its input """
402
+ def __init__(self, submodule):
403
+ super(ShortcutBlock, self).__init__()
404
+ self.sub = submodule
405
+
406
+ def forward(self, x):
407
+ output = x + self.sub(x)
408
+ return output
409
+
410
+ def __repr__(self):
411
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
412
+
413
+
414
+ def sequential(*args):
415
+ """ Flatten Sequential. It unwraps nn.Sequential. """
416
+ if len(args) == 1:
417
+ if isinstance(args[0], OrderedDict):
418
+ raise NotImplementedError('sequential does not support OrderedDict input.')
419
+ return args[0] # No sequential is needed.
420
+ modules = []
421
+ for module in args:
422
+ if isinstance(module, nn.Sequential):
423
+ for submodule in module.children():
424
+ modules.append(submodule)
425
+ elif isinstance(module, nn.Module):
426
+ modules.append(module)
427
+ return nn.Sequential(*modules)
428
+
429
+
430
+ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
431
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
432
+ spectral_norm=False):
433
+ """ Conv layer with padding, normalization, activation """
434
+ assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
435
+ padding = get_valid_padding(kernel_size, dilation)
436
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
437
+ padding = padding if pad_type == 'zero' else 0
438
+
439
+ if convtype=='PartialConv2D':
440
+ from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
441
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
442
+ dilation=dilation, bias=bias, groups=groups)
443
+ elif convtype=='DeformConv2D':
444
+ from torchvision.ops import DeformConv2d # not tested
445
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
446
+ dilation=dilation, bias=bias, groups=groups)
447
+ elif convtype=='Conv3D':
448
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
449
+ dilation=dilation, bias=bias, groups=groups)
450
+ else:
451
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
452
+ dilation=dilation, bias=bias, groups=groups)
453
+
454
+ if spectral_norm:
455
+ c = nn.utils.spectral_norm(c)
456
+
457
+ a = act(act_type) if act_type else None
458
+ if 'CNA' in mode:
459
+ n = norm(norm_type, out_nc) if norm_type else None
460
+ return sequential(p, c, n, a)
461
+ elif mode == 'NAC':
462
+ if norm_type is None and act_type is not None:
463
+ a = act(act_type, inplace=False)
464
+ n = norm(norm_type, in_nc) if norm_type else None
465
+ return sequential(n, a, p, c)
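A small sanity-check sketch for the generator defined above: a plain RGB ESRGAN-style RRDBNet at 4x should grow spatial dimensions by the upscale factor (nb=2 is used here only to keep the example light; real 4x models use nb=23):

import torch
from modules.esrgan_model_arch import RRDBNet

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=2, upscale=4).eval()
with torch.no_grad():
    out = net(torch.zeros(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 128, 128]) -- two upconv stages of 2x each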
modules/extensions.py ADDED
@@ -0,0 +1,163 @@
1
+ import os
2
+ import threading
3
+
4
+ from modules import shared, errors, cache
5
+ from modules.gitpython_hack import Repo
6
+ from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
7
+
8
+ extensions = []
9
+
10
+ os.makedirs(extensions_dir, exist_ok=True)
11
+
12
+
13
+ def active():
14
+ if shared.opts.disable_all_extensions == "all":
15
+ return []
16
+ elif shared.opts.disable_all_extensions == "extra":
17
+ return [x for x in extensions if x.enabled and x.is_builtin]
18
+ else:
19
+ return [x for x in extensions if x.enabled]
20
+
21
+
22
+ class Extension:
23
+ lock = threading.Lock()
24
+ cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
25
+
26
+ def __init__(self, name, path, enabled=True, is_builtin=False):
27
+ self.name = name
28
+ self.path = path
29
+ self.enabled = enabled
30
+ self.status = ''
31
+ self.can_update = False
32
+ self.is_builtin = is_builtin
33
+ self.commit_hash = ''
34
+ self.commit_date = None
35
+ self.version = ''
36
+ self.branch = None
37
+ self.remote = None
38
+ self.have_info_from_repo = False
39
+
40
+ def to_dict(self):
41
+ return {x: getattr(self, x) for x in self.cached_fields}
42
+
43
+ def from_dict(self, d):
44
+ for field in self.cached_fields:
45
+ setattr(self, field, d[field])
46
+
47
+ def read_info_from_repo(self):
48
+ if self.is_builtin or self.have_info_from_repo:
49
+ return
50
+
51
+ def read_from_repo():
52
+ with self.lock:
53
+ if self.have_info_from_repo:
54
+ return
55
+
56
+ self.do_read_info_from_repo()
57
+
58
+ return self.to_dict()
59
+ try:
60
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
61
+ self.from_dict(d)
62
+ except FileNotFoundError:
63
+ pass
64
+ self.status = 'unknown'
65
+
66
+ def do_read_info_from_repo(self):
67
+ repo = None
68
+ try:
69
+ if os.path.exists(os.path.join(self.path, ".git")):
70
+ repo = Repo(self.path)
71
+ except Exception:
72
+ errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
73
+
74
+ if repo is None or repo.bare:
75
+ self.remote = None
76
+ else:
77
+ try:
78
+ self.remote = next(repo.remote().urls, None)
79
+ commit = repo.head.commit
80
+ self.commit_date = commit.committed_date
81
+ if repo.active_branch:
82
+ self.branch = repo.active_branch.name
83
+ self.commit_hash = commit.hexsha
84
+ self.version = self.commit_hash[:8]
85
+
86
+ except Exception:
87
+ errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
88
+ self.remote = None
89
+
90
+ self.have_info_from_repo = True
91
+
92
+ def list_files(self, subdir, extension):
93
+ from modules import scripts
94
+
95
+ dirpath = os.path.join(self.path, subdir)
96
+ if not os.path.isdir(dirpath):
97
+ return []
98
+
99
+ res = []
100
+ for filename in sorted(os.listdir(dirpath)):
101
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
102
+
103
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
104
+
105
+ return res
106
+
107
+ def check_updates(self):
108
+ repo = Repo(self.path)
109
+ for fetch in repo.remote().fetch(dry_run=True):
110
+ if fetch.flags != fetch.HEAD_UPTODATE:
111
+ self.can_update = True
112
+ self.status = "new commits"
113
+ return
114
+
115
+ try:
116
+ origin = repo.rev_parse('origin')
117
+ if repo.head.commit != origin:
118
+ self.can_update = True
119
+ self.status = "behind HEAD"
120
+ return
121
+ except Exception:
122
+ self.can_update = False
123
+ self.status = "unknown (remote error)"
124
+ return
125
+
126
+ self.can_update = False
127
+ self.status = "latest"
128
+
129
+ def fetch_and_reset_hard(self, commit='origin'):
130
+ repo = Repo(self.path)
131
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
132
+ # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
133
+ repo.git.fetch(all=True)
134
+ repo.git.reset(commit, hard=True)
135
+ self.have_info_from_repo = False
136
+
137
+
138
+ def list_extensions():
139
+ extensions.clear()
140
+
141
+ if not os.path.isdir(extensions_dir):
142
+ return
143
+
144
+ if shared.opts.disable_all_extensions == "all":
145
+ print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
146
+ elif shared.opts.disable_all_extensions == "extra":
147
+ print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
148
+
149
+ extension_paths = []
150
+ for dirname in [extensions_dir, extensions_builtin_dir]:
151
+ if not os.path.isdir(dirname):
152
+ return
153
+
154
+ for extension_dirname in sorted(os.listdir(dirname)):
155
+ path = os.path.join(dirname, extension_dirname)
156
+ if not os.path.isdir(path):
157
+ continue
158
+
159
+ extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
160
+
161
+ for dirname, path, is_builtin in extension_paths:
162
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
163
+ extensions.append(extension)
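A brief sketch of how this registry is typically consumed, assuming shared.opts has already been loaded (list_extensions() is normally called during webui startup):

from modules import extensions

extensions.list_extensions()
for ext in extensions.active():
    ext.read_info_from_repo()                  # cached; status falls back to 'unknown' without git data
    kind = "builtin" if ext.is_builtin else "user"
    print(f"{ext.name} [{kind}] {ext.version or '(no version)'}")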
modules/extra_networks.py ADDED
@@ -0,0 +1,179 @@
1
+ import re
2
+ from collections import defaultdict
3
+
4
+ from modules import errors
5
+
6
+ extra_network_registry = {}
7
+ extra_network_aliases = {}
8
+
9
+
10
+ def initialize():
11
+ extra_network_registry.clear()
12
+ extra_network_aliases.clear()
13
+
14
+
15
+ def register_extra_network(extra_network):
16
+ extra_network_registry[extra_network.name] = extra_network
17
+
18
+
19
+ def register_extra_network_alias(extra_network, alias):
20
+ extra_network_aliases[alias] = extra_network
21
+
22
+
23
+ def register_default_extra_networks():
24
+ from modules.extra_networks_hypernet import ExtraNetworkHypernet
25
+ register_extra_network(ExtraNetworkHypernet())
26
+
27
+
28
+ class ExtraNetworkParams:
29
+ def __init__(self, items=None):
30
+ self.items = items or []
31
+ self.positional = []
32
+ self.named = {}
33
+
34
+ for item in self.items:
35
+ parts = item.split('=', 2) if isinstance(item, str) else [item]
36
+ if len(parts) == 2:
37
+ self.named[parts[0]] = parts[1]
38
+ else:
39
+ self.positional.append(item)
40
+
41
+ def __eq__(self, other):
42
+ return self.items == other.items
43
+
44
+
45
+ class ExtraNetwork:
46
+ def __init__(self, name):
47
+ self.name = name
48
+
49
+ def activate(self, p, params_list):
50
+ """
51
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
52
+ Passes arguments related to this extra network in params_list.
53
+ User passes arguments by specifying this in his prompt:
54
+
55
+ <name:arg1:arg2:arg3>
56
+
57
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
58
+ separated by colon.
59
+
60
+ Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with an empty params_list -
61
+ in this case, all effects of this extra network should be disabled.
62
+
63
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
64
+
65
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
66
+
67
+ > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
68
+
69
+ params_list will be:
70
+
71
+ [
72
+ ExtraNetworkParams(items=["agm", "1.1"]),
73
+ ExtraNetworkParams(items=["ray"])
74
+ ]
75
+
76
+ """
77
+ raise NotImplementedError
78
+
79
+ def deactivate(self, p):
80
+ """
81
+ Called at the end of processing for housekeeping. No need to do anything here.
82
+ """
83
+
84
+ raise NotImplementedError
85
+
86
+
87
+ def activate(p, extra_network_data):
88
+ """call activate for extra networks in extra_network_data in specified order, then call
89
+ activate for all remaining registered networks with an empty argument list"""
90
+
91
+ activated = []
92
+
93
+ for extra_network_name, extra_network_args in extra_network_data.items():
94
+ extra_network = extra_network_registry.get(extra_network_name, None)
95
+
96
+ if extra_network is None:
97
+ extra_network = extra_network_aliases.get(extra_network_name, None)
98
+
99
+ if extra_network is None:
100
+ print(f"Skipping unknown extra network: {extra_network_name}")
101
+ continue
102
+
103
+ try:
104
+ extra_network.activate(p, extra_network_args)
105
+ activated.append(extra_network)
106
+ except Exception as e:
107
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
108
+
109
+ for extra_network_name, extra_network in extra_network_registry.items():
110
+ if extra_network in activated:
111
+ continue
112
+
113
+ try:
114
+ extra_network.activate(p, [])
115
+ except Exception as e:
116
+ errors.display(e, f"activating extra network {extra_network_name}")
117
+
118
+ if p.scripts is not None:
119
+ p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
120
+
121
+
122
+ def deactivate(p, extra_network_data):
123
+ """call deactivate for extra networks in extra_network_data in specified order, then call
124
+ deactivate for all remaining registered networks"""
125
+
126
+ for extra_network_name in extra_network_data:
127
+ extra_network = extra_network_registry.get(extra_network_name, None)
128
+ if extra_network is None:
129
+ continue
130
+
131
+ try:
132
+ extra_network.deactivate(p)
133
+ except Exception as e:
134
+ errors.display(e, f"deactivating extra network {extra_network_name}")
135
+
136
+ for extra_network_name, extra_network in extra_network_registry.items():
137
+ args = extra_network_data.get(extra_network_name, None)
138
+ if args is not None:
139
+ continue
140
+
141
+ try:
142
+ extra_network.deactivate(p)
143
+ except Exception as e:
144
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
145
+
146
+
147
+ re_extra_net = re.compile(r"<(\w+):([^>]+)>")
148
+
149
+
150
+ def parse_prompt(prompt):
151
+ res = defaultdict(list)
152
+
153
+ def found(m):
154
+ name = m.group(1)
155
+ args = m.group(2)
156
+
157
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
158
+
159
+ return ""
160
+
161
+ prompt = re.sub(re_extra_net, found, prompt)
162
+
163
+ return prompt, res
164
+
165
+
166
+ def parse_prompts(prompts):
167
+ res = []
168
+ extra_data = None
169
+
170
+ for prompt in prompts:
171
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
172
+
173
+ if extra_data is None:
174
+ extra_data = parsed_extra_data
175
+
176
+ res.append(updated_prompt)
177
+
178
+ return res, extra_data
179
+
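To make the tag grammar above concrete, here is what parse_prompt() returns for a prompt containing extra-network tags; the 'lora' name is illustrative (its handler is registered by a separate built-in extension, only 'hypernet' is registered by default here):

from modules import extra_networks

prompt, data = extra_networks.parse_prompt("1girl, <hypernet:agm:1.1> <lora:detail:0.6>")
print(prompt)                       # '1girl,  '  -- tags are stripped from the prompt text
print(data["hypernet"][0].items)    # ['agm', '1.1']
print(data["lora"][0].positional)   # ['detail', '0.6'] -- no key=value pairs, so .named stays {}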
modules/extra_networks_hypernet.py ADDED
@@ -0,0 +1,28 @@
1
+ from modules import extra_networks, shared
2
+ from modules.hypernetworks import hypernetwork
3
+
4
+
5
+ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
6
+ def __init__(self):
7
+ super().__init__('hypernet')
8
+
9
+ def activate(self, p, params_list):
10
+ additional = shared.opts.sd_hypernetwork
11
+
12
+ if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
13
+ hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
14
+ p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
15
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
16
+
17
+ names = []
18
+ multipliers = []
19
+ for params in params_list:
20
+ assert params.items
21
+
22
+ names.append(params.items[0])
23
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
24
+
25
+ hypernetwork.load_hypernetworks(names, multipliers)
26
+
27
+ def deactivate(self, p):
28
+ pass
modules/extras.py ADDED
@@ -0,0 +1,303 @@
1
+ import os
2
+ import re
3
+ import shutil
4
+ import json
5
+
6
+
7
+ import torch
8
+ import tqdm
9
+
10
+ from modules import shared, images, sd_models, sd_vae, sd_models_config
11
+ from modules.ui_common import plaintext_to_html
12
+ import gradio as gr
13
+ import safetensors.torch
14
+
15
+
16
+ def run_pnginfo(image):
17
+ if image is None:
18
+ return '', '', ''
19
+
20
+ geninfo, items = images.read_info_from_image(image)
21
+ items = {**{'parameters': geninfo}, **items}
22
+
23
+ info = ''
24
+ for key, text in items.items():
25
+ info += f"""
26
+ <div>
27
+ <p><b>{plaintext_to_html(str(key))}</b></p>
28
+ <p>{plaintext_to_html(str(text))}</p>
29
+ </div>
30
+ """.strip()+"\n"
31
+
32
+ if len(info) == 0:
33
+ message = "Nothing found in the image."
34
+ info = f"<div><p>{message}<p></div>"
35
+
36
+ return '', geninfo, info
37
+
38
+
39
+ def create_config(ckpt_result, config_source, a, b, c):
40
+ def config(x):
41
+ res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
42
+ return res if res != shared.sd_default_config else None
43
+
44
+ if config_source == 0:
45
+ cfg = config(a) or config(b) or config(c)
46
+ elif config_source == 1:
47
+ cfg = config(b)
48
+ elif config_source == 2:
49
+ cfg = config(c)
50
+ else:
51
+ cfg = None
52
+
53
+ if cfg is None:
54
+ return
55
+
56
+ filename, _ = os.path.splitext(ckpt_result)
57
+ checkpoint_filename = filename + ".yaml"
58
+
59
+ print("Copying config:")
60
+ print(" from:", cfg)
61
+ print(" to:", checkpoint_filename)
62
+ shutil.copyfile(cfg, checkpoint_filename)
63
+
64
+
65
+ checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
66
+
67
+
68
+ def to_half(tensor, enable):
69
+ if enable and tensor.dtype == torch.float:
70
+ return tensor.half()
71
+
72
+ return tensor
73
+
74
+
75
+ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
76
+ shared.state.begin(job="model-merge")
77
+
78
+ def fail(message):
79
+ shared.state.textinfo = message
80
+ shared.state.end()
81
+ return [*[gr.update() for _ in range(4)], message]
82
+
83
+ def weighted_sum(theta0, theta1, alpha):
84
+ return ((1 - alpha) * theta0) + (alpha * theta1)
85
+
86
+ def get_difference(theta1, theta2):
87
+ return theta1 - theta2
88
+
89
+ def add_difference(theta0, theta1_2_diff, alpha):
90
+ return theta0 + (alpha * theta1_2_diff)
91
+
92
+ def filename_weighted_sum():
93
+ a = primary_model_info.model_name
94
+ b = secondary_model_info.model_name
95
+ Ma = round(1 - multiplier, 2)
96
+ Mb = round(multiplier, 2)
97
+
98
+ return f"{Ma}({a}) + {Mb}({b})"
99
+
100
+ def filename_add_difference():
101
+ a = primary_model_info.model_name
102
+ b = secondary_model_info.model_name
103
+ c = tertiary_model_info.model_name
104
+ M = round(multiplier, 2)
105
+
106
+ return f"{a} + {M}({b} - {c})"
107
+
108
+ def filename_nothing():
109
+ return primary_model_info.model_name
110
+
111
+ theta_funcs = {
112
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
113
+ "Add difference": (filename_add_difference, get_difference, add_difference),
114
+ "No interpolation": (filename_nothing, None, None),
115
+ }
116
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
117
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
118
+
119
+ if not primary_model_name:
120
+ return fail("Failed: Merging requires a primary model.")
121
+
122
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
123
+
124
+ if theta_func2 and not secondary_model_name:
125
+ return fail("Failed: Merging requires a secondary model.")
126
+
127
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
128
+
129
+ if theta_func1 and not tertiary_model_name:
130
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
131
+
132
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
133
+
134
+ result_is_inpainting_model = False
135
+ result_is_instruct_pix2pix_model = False
136
+
137
+ if theta_func2:
138
+ shared.state.textinfo = "Loading B"
139
+ print(f"Loading {secondary_model_info.filename}...")
140
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
141
+ else:
142
+ theta_1 = None
143
+
144
+ if theta_func1:
145
+ shared.state.textinfo = "Loading C"
146
+ print(f"Loading {tertiary_model_info.filename}...")
147
+ theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
148
+
149
+ shared.state.textinfo = 'Merging B and C'
150
+ shared.state.sampling_steps = len(theta_1.keys())
151
+ for key in tqdm.tqdm(theta_1.keys()):
152
+ if key in checkpoint_dict_skip_on_merge:
153
+ continue
154
+
155
+ if 'model' in key:
156
+ if key in theta_2:
157
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
158
+ theta_1[key] = theta_func1(theta_1[key], t2)
159
+ else:
160
+ theta_1[key] = torch.zeros_like(theta_1[key])
161
+
162
+ shared.state.sampling_step += 1
163
+ del theta_2
164
+
165
+ shared.state.nextjob()
166
+
167
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
168
+ print(f"Loading {primary_model_info.filename}...")
169
+ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
170
+
171
+ print("Merging...")
172
+ shared.state.textinfo = 'Merging A and B'
173
+ shared.state.sampling_steps = len(theta_0.keys())
174
+ for key in tqdm.tqdm(theta_0.keys()):
175
+ if theta_1 and 'model' in key and key in theta_1:
176
+
177
+ if key in checkpoint_dict_skip_on_merge:
178
+ continue
179
+
180
+ a = theta_0[key]
181
+ b = theta_1[key]
182
+
183
+ # this enables merging an inpainting model (A) with another one (B);
184
+ # where a normal model would have 4 channels for latent space, an inpainting model would
185
+ # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
186
+ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
187
+ if a.shape[1] == 4 and b.shape[1] == 9:
188
+ raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
189
+ if a.shape[1] == 4 and b.shape[1] == 8:
190
+ raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
191
+
192
+ if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
193
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
194
+ result_is_instruct_pix2pix_model = True
195
+ else:
196
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
197
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
198
+ result_is_inpainting_model = True
199
+ else:
200
+ theta_0[key] = theta_func2(a, b, multiplier)
201
+
202
+ theta_0[key] = to_half(theta_0[key], save_as_half)
203
+
204
+ shared.state.sampling_step += 1
205
+
206
+ del theta_1
207
+
208
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
209
+ if bake_in_vae_filename is not None:
210
+ print(f"Baking in VAE from {bake_in_vae_filename}")
211
+ shared.state.textinfo = 'Baking in VAE'
212
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
213
+
214
+ for key in vae_dict.keys():
215
+ theta_0_key = 'first_stage_model.' + key
216
+ if theta_0_key in theta_0:
217
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
218
+
219
+ del vae_dict
220
+
221
+ if save_as_half and not theta_func2:
222
+ for key in theta_0.keys():
223
+ theta_0[key] = to_half(theta_0[key], save_as_half)
224
+
225
+ if discard_weights:
226
+ regex = re.compile(discard_weights)
227
+ for key in list(theta_0):
228
+ if re.search(regex, key):
229
+ theta_0.pop(key, None)
230
+
231
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
232
+
233
+ filename = filename_generator() if custom_name == '' else custom_name
234
+ filename += ".inpainting" if result_is_inpainting_model else ""
235
+ filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
236
+ filename += "." + checkpoint_format
237
+
238
+ output_modelname = os.path.join(ckpt_dir, filename)
239
+
240
+ shared.state.nextjob()
241
+ shared.state.textinfo = "Saving"
242
+ print(f"Saving to {output_modelname}...")
243
+
244
+ metadata = None
245
+
246
+ if save_metadata:
247
+ metadata = {"format": "pt"}
248
+
249
+ merge_recipe = {
250
+ "type": "webui", # indicate this model was merged with webui's built-in merger
251
+ "primary_model_hash": primary_model_info.sha256,
252
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
253
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
254
+ "interp_method": interp_method,
255
+ "multiplier": multiplier,
256
+ "save_as_half": save_as_half,
257
+ "custom_name": custom_name,
258
+ "config_source": config_source,
259
+ "bake_in_vae": bake_in_vae,
260
+ "discard_weights": discard_weights,
261
+ "is_inpainting": result_is_inpainting_model,
262
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
263
+ }
264
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
265
+
266
+ sd_merge_models = {}
267
+
268
+ def add_model_metadata(checkpoint_info):
269
+ checkpoint_info.calculate_shorthash()
270
+ sd_merge_models[checkpoint_info.sha256] = {
271
+ "name": checkpoint_info.name,
272
+ "legacy_hash": checkpoint_info.hash,
273
+ "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
274
+ }
275
+
276
+ sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
277
+
278
+ add_model_metadata(primary_model_info)
279
+ if secondary_model_info:
280
+ add_model_metadata(secondary_model_info)
281
+ if tertiary_model_info:
282
+ add_model_metadata(tertiary_model_info)
283
+
284
+ metadata["sd_merge_models"] = json.dumps(sd_merge_models)
285
+
286
+ _, extension = os.path.splitext(output_modelname)
287
+ if extension.lower() == ".safetensors":
288
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
289
+ else:
290
+ torch.save(theta_0, output_modelname)
291
+
292
+ sd_models.list_models()
293
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
294
+ if created_model:
295
+ created_model.calculate_shorthash()
296
+
297
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
298
+
299
+ print(f"Checkpoint saved to {output_modelname}.")
300
+ shared.state.textinfo = "Checkpoint saved"
301
+ shared.state.end()
302
+
303
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
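The interpolation rules used by the merger reduce to simple per-tensor arithmetic; a toy illustration with stand-in tensors (alpha plays the role of the multiplier slider):

import torch

theta_a = torch.tensor([1.0, 2.0])   # primary model weight
theta_b = torch.tensor([3.0, 6.0])   # secondary model weight
theta_c = torch.tensor([2.0, 2.0])   # tertiary model weight (only used by "Add difference")
alpha = 0.5

weighted_sum   = (1 - alpha) * theta_a + alpha * theta_b      # -> tensor([2., 4.])
add_difference = theta_a + alpha * (theta_b - theta_c)        # -> tensor([1.5000, 4.0000])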
modules/face_restoration.py ADDED
@@ -0,0 +1,19 @@
1
+ from modules import shared
2
+
3
+
4
+ class FaceRestoration:
5
+ def name(self):
6
+ return "None"
7
+
8
+ def restore(self, np_image):
9
+ return np_image
10
+
11
+
12
+ def restore_faces(np_image):
13
+ face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14
+ if len(face_restorers) == 0:
15
+ return np_image
16
+
17
+ face_restorer = face_restorers[0]
18
+
19
+ return face_restorer.restore(np_image)
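The class above is only an interface; concrete restorers elsewhere in this commit (the GFPGAN and CodeFormer wrappers) subclass it and register themselves. A minimal, purely illustrative subclass:

from modules import face_restoration, shared

class PassthroughRestorer(face_restoration.FaceRestoration):
    def name(self):
        return "Passthrough"

    def restore(self, np_image):
        # a real restorer would detect faces and return the corrected image
        return np_image

shared.face_restorers.append(PassthroughRestorer())   # makes it selectable via opts.face_restoration_model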
modules/generation_parameters_copypaste.py ADDED
@@ -0,0 +1,439 @@
1
+ import base64
2
+ import io
3
+ import json
4
+ import os
5
+ import re
6
+
7
+ import gradio as gr
8
+ from modules.paths import data_path
9
+ from modules import shared, ui_tempdir, script_callbacks
10
+ from PIL import Image
11
+
12
+ re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
13
+ re_param = re.compile(re_param_code)
14
+ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
15
+ re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
16
+ type_of_gr_update = type(gr.update())
17
+
18
+ paste_fields = {}
19
+ registered_param_bindings = []
20
+
21
+
22
+ class ParamBinding:
23
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
24
+ self.paste_button = paste_button
25
+ self.tabname = tabname
26
+ self.source_text_component = source_text_component
27
+ self.source_image_component = source_image_component
28
+ self.source_tabname = source_tabname
29
+ self.override_settings_component = override_settings_component
30
+ self.paste_field_names = paste_field_names or []
31
+
32
+
33
+ def reset():
34
+ paste_fields.clear()
35
+
36
+
37
+ def quote(text):
38
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
39
+ return text
40
+
41
+ return json.dumps(text, ensure_ascii=False)
42
+
43
+
44
+ def unquote(text):
45
+ if len(text) == 0 or text[0] != '"' or text[-1] != '"':
46
+ return text
47
+
48
+ try:
49
+ return json.loads(text)
50
+ except Exception:
51
+ return text
52
+
53
+
54
+ def image_from_url_text(filedata):
55
+ if filedata is None:
56
+ return None
57
+
58
+ if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
59
+ filedata = filedata[0]
60
+
61
+ if type(filedata) == dict and filedata.get("is_file", False):
62
+ filename = filedata["name"]
63
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
64
+ assert is_in_right_dir, 'trying to open image file outside of allowed directories'
65
+
66
+ filename = filename.rsplit('?', 1)[0]
67
+ return Image.open(filename)
68
+
69
+ if type(filedata) == list:
70
+ if len(filedata) == 0:
71
+ return None
72
+
73
+ filedata = filedata[0]
74
+
75
+ if filedata.startswith("data:image/png;base64,"):
76
+ filedata = filedata[len("data:image/png;base64,"):]
77
+
78
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
79
+ image = Image.open(io.BytesIO(filedata))
80
+ return image
81
+
82
+
83
+ def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
84
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
85
+
86
+ # backwards compatibility for existing extensions
87
+ import modules.ui
88
+ if tabname == 'txt2img':
89
+ modules.ui.txt2img_paste_fields = fields
90
+ elif tabname == 'img2img':
91
+ modules.ui.img2img_paste_fields = fields
92
+
93
+
94
+ def create_buttons(tabs_list):
95
+ buttons = {}
96
+ for tab in tabs_list:
97
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
98
+ return buttons
99
+
100
+
101
+ def bind_buttons(buttons, send_image, send_generate_info):
102
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
103
+ for tabname, button in buttons.items():
104
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
105
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
106
+
107
+ register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
108
+
109
+
110
+ def register_paste_params_button(binding: ParamBinding):
111
+ registered_param_bindings.append(binding)
112
+
113
+
114
+ def connect_paste_params_buttons():
115
+ binding: ParamBinding
116
+ for binding in registered_param_bindings:
117
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
118
+ fields = paste_fields[binding.tabname]["fields"]
119
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
120
+
121
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
122
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
123
+
124
+ if binding.source_image_component and destination_image_component:
125
+ if isinstance(binding.source_image_component, gr.Gallery):
126
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
127
+ jsfunc = "extract_image_from_gallery"
128
+ else:
129
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
130
+ jsfunc = None
131
+
132
+ binding.paste_button.click(
133
+ fn=func,
134
+ _js=jsfunc,
135
+ inputs=[binding.source_image_component],
136
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
137
+ show_progress=False,
138
+ )
139
+
140
+ if binding.source_text_component is not None and fields is not None:
141
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
142
+
143
+ if binding.source_tabname is not None and fields is not None:
144
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
145
+ binding.paste_button.click(
146
+ fn=lambda *x: x,
147
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
148
+ outputs=[field for field, name in fields if name in paste_field_names],
149
+ show_progress=False,
150
+ )
151
+
152
+ binding.paste_button.click(
153
+ fn=None,
154
+ _js=f"switch_to_{binding.tabname}",
155
+ inputs=None,
156
+ outputs=None,
157
+ show_progress=False,
158
+ )
159
+
160
+
161
+ def send_image_and_dimensions(x):
162
+ if isinstance(x, Image.Image):
163
+ img = x
164
+ else:
165
+ img = image_from_url_text(x)
166
+
167
+ if shared.opts.send_size and isinstance(img, Image.Image):
168
+ w = img.width
169
+ h = img.height
170
+ else:
171
+ w = gr.update()
172
+ h = gr.update()
173
+
174
+ return img, w, h
175
+
176
+
177
+ def restore_old_hires_fix_params(res):
178
+ """for infotexts that specify old First pass size parameter, convert it into
179
+ width, height, and hr scale"""
180
+
181
+ firstpass_width = res.get('First pass size-1', None)
182
+ firstpass_height = res.get('First pass size-2', None)
183
+
184
+ if shared.opts.use_old_hires_fix_width_height:
185
+ hires_width = int(res.get("Hires resize-1", 0))
186
+ hires_height = int(res.get("Hires resize-2", 0))
187
+
188
+ if hires_width and hires_height:
189
+ res['Size-1'] = hires_width
190
+ res['Size-2'] = hires_height
191
+ return
192
+
193
+ if firstpass_width is None or firstpass_height is None:
194
+ return
195
+
196
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
197
+ width = int(res.get("Size-1", 512))
198
+ height = int(res.get("Size-2", 512))
199
+
200
+ if firstpass_width == 0 or firstpass_height == 0:
201
+ from modules import processing
202
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
203
+
204
+ res['Size-1'] = firstpass_width
205
+ res['Size-2'] = firstpass_height
206
+ res['Hires resize-1'] = width
207
+ res['Hires resize-2'] = height
208
+
209
+
210
+ def parse_generation_parameters(x: str):
211
+ """parses generation parameters string, the one you see in text field under the picture in UI:
212
+ ```
213
+ girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
214
+ Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
215
+ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
216
+ ```
217
+
218
+ returns a dict with field values
219
+ """
220
+
221
+ res = {}
222
+
223
+ prompt = ""
224
+ negative_prompt = ""
225
+
226
+ done_with_prompt = False
227
+
228
+ *lines, lastline = x.strip().split("\n")
229
+ if len(re_param.findall(lastline)) < 3:
230
+ lines.append(lastline)
231
+ lastline = ''
232
+
233
+ for line in lines:
234
+ line = line.strip()
235
+ if line.startswith("Negative prompt:"):
236
+ done_with_prompt = True
237
+ line = line[16:].strip()
238
+ if done_with_prompt:
239
+ negative_prompt += ("" if negative_prompt == "" else "\n") + line
240
+ else:
241
+ prompt += ("" if prompt == "" else "\n") + line
242
+
243
+ if shared.opts.infotext_styles != "Ignore":
244
+ found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
245
+
246
+ if shared.opts.infotext_styles == "Apply":
247
+ res["Styles array"] = found_styles
248
+ elif shared.opts.infotext_styles == "Apply if any" and found_styles:
249
+ res["Styles array"] = found_styles
250
+
251
+ res["Prompt"] = prompt
252
+ res["Negative prompt"] = negative_prompt
253
+
254
+ for k, v in re_param.findall(lastline):
255
+ try:
256
+ if v[0] == '"' and v[-1] == '"':
257
+ v = unquote(v)
258
+
259
+ m = re_imagesize.match(v)
260
+ if m is not None:
261
+ res[f"{k}-1"] = m.group(1)
262
+ res[f"{k}-2"] = m.group(2)
263
+ else:
264
+ res[k] = v
265
+ except Exception:
266
+ print(f"Error parsing \"{k}: {v}\"")
267
+
268
+ # Missing CLIP skip means it was set to 1 (the default)
269
+ if "Clip skip" not in res:
270
+ res["Clip skip"] = "1"
271
+
272
+ hypernet = res.get("Hypernet", None)
273
+ if hypernet is not None:
274
+ res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
275
+
276
+ if "Hires resize-1" not in res:
277
+ res["Hires resize-1"] = 0
278
+ res["Hires resize-2"] = 0
279
+
280
+ if "Hires sampler" not in res:
281
+ res["Hires sampler"] = "Use same sampler"
282
+
283
+ if "Hires prompt" not in res:
284
+ res["Hires prompt"] = ""
285
+
286
+ if "Hires negative prompt" not in res:
287
+ res["Hires negative prompt"] = ""
288
+
289
+ restore_old_hires_fix_params(res)
290
+
291
+ # Missing RNG means the default was set, which is GPU RNG
292
+ if "RNG" not in res:
293
+ res["RNG"] = "GPU"
294
+
295
+ if "Schedule type" not in res:
296
+ res["Schedule type"] = "Automatic"
297
+
298
+ if "Schedule max sigma" not in res:
299
+ res["Schedule max sigma"] = 0
300
+
301
+ if "Schedule min sigma" not in res:
302
+ res["Schedule min sigma"] = 0
303
+
304
+ if "Schedule rho" not in res:
305
+ res["Schedule rho"] = 0
306
+
307
+ return res
308
+
309
+
310
+ infotext_to_setting_name_mapping = [
311
+ ('Clip skip', 'CLIP_stop_at_last_layers', ),
312
+ ('Conditional mask weight', 'inpainting_mask_weight'),
313
+ ('Model hash', 'sd_model_checkpoint'),
314
+ ('ENSD', 'eta_noise_seed_delta'),
315
+ ('Schedule type', 'k_sched_type'),
316
+ ('Schedule max sigma', 'sigma_max'),
317
+ ('Schedule min sigma', 'sigma_min'),
318
+ ('Schedule rho', 'rho'),
319
+ ('Noise multiplier', 'initial_noise_multiplier'),
320
+ ('Eta', 'eta_ancestral'),
321
+ ('Eta DDIM', 'eta_ddim'),
322
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
323
+ ('UniPC variant', 'uni_pc_variant'),
324
+ ('UniPC skip type', 'uni_pc_skip_type'),
325
+ ('UniPC order', 'uni_pc_order'),
326
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
327
+ ('Token merging ratio', 'token_merging_ratio'),
328
+ ('Token merging ratio hr', 'token_merging_ratio_hr'),
329
+ ('RNG', 'randn_source'),
330
+ ('NGMS', 's_min_uncond'),
331
+ ('Pad conds', 'pad_cond_uncond'),
332
+ ]
333
+
334
+
335
+ def create_override_settings_dict(text_pairs):
336
+ """creates processing's override_settings parameters from gradio's multiselect
337
+
338
+ Example input:
339
+ ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
340
+
341
+ Example output:
342
+ {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
343
+ """
344
+
345
+ res = {}
346
+
347
+ params = {}
348
+ for pair in text_pairs:
349
+ k, v = pair.split(":", maxsplit=1)
350
+
351
+ params[k] = v.strip()
352
+
353
+ for param_name, setting_name in infotext_to_setting_name_mapping:
354
+ value = params.get(param_name, None)
355
+
356
+ if value is None:
357
+ continue
358
+
359
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
360
+
361
+ return res
362
+
363
+
364
+ def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
365
+ def paste_func(prompt):
366
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
367
+ filename = os.path.join(data_path, "params.txt")
368
+ if os.path.exists(filename):
369
+ with open(filename, "r", encoding="utf8") as file:
370
+ prompt = file.read()
371
+
372
+ params = parse_generation_parameters(prompt)
373
+ script_callbacks.infotext_pasted_callback(prompt, params)
374
+ res = []
375
+
376
+ for output, key in paste_fields:
377
+ if callable(key):
378
+ v = key(params)
379
+ else:
380
+ v = params.get(key, None)
381
+
382
+ if v is None:
383
+ res.append(gr.update())
384
+ elif isinstance(v, type_of_gr_update):
385
+ res.append(v)
386
+ else:
387
+ try:
388
+ valtype = type(output.value)
389
+
390
+ if valtype == bool and v == "False":
391
+ val = False
392
+ else:
393
+ val = valtype(v)
394
+
395
+ res.append(gr.update(value=val))
396
+ except Exception:
397
+ res.append(gr.update())
398
+
399
+ return res
400
+
401
+ if override_settings_component is not None:
402
+ def paste_settings(params):
403
+ vals = {}
404
+
405
+ for param_name, setting_name in infotext_to_setting_name_mapping:
406
+ v = params.get(param_name, None)
407
+ if v is None:
408
+ continue
409
+
410
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
411
+ continue
412
+
413
+ v = shared.opts.cast_value(setting_name, v)
414
+ current_value = getattr(shared.opts, setting_name, None)
415
+
416
+ if v == current_value:
417
+ continue
418
+
419
+ vals[param_name] = v
420
+
421
+ vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
422
+
423
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
424
+
425
+ paste_fields = paste_fields + [(override_settings_component, paste_settings)]
426
+
427
+ button.click(
428
+ fn=paste_func,
429
+ inputs=[input_comp],
430
+ outputs=[x[0] for x in paste_fields],
431
+ show_progress=False,
432
+ )
433
+ button.click(
434
+ fn=None,
435
+ _js=f"recalculate_prompts_{tabname}",
436
+ inputs=[],
437
+ outputs=[],
438
+ show_progress=False,
439
+ )
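For orientation, here is a minimal usage sketch (illustrative only, not part of the commit) of what parse_generation_parameters above returns for a typical infotext. The import path comes from the file being added; the sample prompt, the filename-free setup, and the note about defaults assume a stock shared.opts configuration:

    # Illustrative sketch; assumes the webui repository root is importable and its dependencies are installed.
    from modules.generation_parameters_copypaste import parse_generation_parameters

    infotext = (
        "a photo of a cat\n"
        "Negative prompt: blurry, low quality\n"
        "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 1, Size: 512x512"
    )
    params = parse_generation_parameters(infotext)
    # params["Prompt"] == "a photo of a cat"
    # params["Negative prompt"] == "blurry, low quality"
    # params["Size-1"] == "512", params["Size-2"] == "512"   (WxH values are split by re_imagesize)
    # params["Clip skip"] == "1", params["RNG"] == "GPU"      (defaults filled in when missing from the text)
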
modules/gfpgan_model.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+
3
+ import facexlib
4
+ import gfpgan
5
+
6
+ import modules.face_restoration
7
+ from modules import paths, shared, devices, modelloader, errors
8
+
9
+ model_dir = "GFPGAN"
10
+ user_path = None
11
+ model_path = os.path.join(paths.models_path, model_dir)
12
+ model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
13
+ have_gfpgan = False
14
+ loaded_gfpgan_model = None
15
+
16
+
17
+ def gfpgann():
18
+ global loaded_gfpgan_model
19
+ global model_path
20
+ if loaded_gfpgan_model is not None:
21
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
22
+ return loaded_gfpgan_model
23
+
24
+ if gfpgan_constructor is None:
25
+ return None
26
+
27
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
28
+ if len(models) == 1 and models[0].startswith("http"):
29
+ model_file = models[0]
30
+ elif len(models) != 0:
31
+ latest_file = max(models, key=os.path.getctime)
32
+ model_file = latest_file
33
+ else:
34
+ print("Unable to load gfpgan model!")
35
+ return None
36
+ if hasattr(facexlib.detection.retinaface, 'device'):
37
+ facexlib.detection.retinaface.device = devices.device_gfpgan
38
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
39
+ loaded_gfpgan_model = model
40
+
41
+ return model
42
+
43
+
44
+ def send_model_to(model, device):
45
+ model.gfpgan.to(device)
46
+ model.face_helper.face_det.to(device)
47
+ model.face_helper.face_parse.to(device)
48
+
49
+
50
+ def gfpgan_fix_faces(np_image):
51
+ model = gfpgann()
52
+ if model is None:
53
+ return np_image
54
+
55
+ send_model_to(model, devices.device_gfpgan)
56
+
57
+ np_image_bgr = np_image[:, :, ::-1]
58
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
59
+ np_image = gfpgan_output_bgr[:, :, ::-1]
60
+
61
+ model.face_helper.clean_all()
62
+
63
+ if shared.opts.face_restoration_unload:
64
+ send_model_to(model, devices.cpu)
65
+
66
+ return np_image
67
+
68
+
69
+ gfpgan_constructor = None
70
+
71
+
72
+ def setup_model(dirname):
73
+ try:
74
+ os.makedirs(model_path, exist_ok=True)
75
+ from gfpgan import GFPGANer
76
+ from facexlib import detection, parsing # noqa: F401
77
+ global user_path
78
+ global have_gfpgan
79
+ global gfpgan_constructor
80
+
81
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
82
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
83
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
84
+
85
+ def my_load_file_from_url(**kwargs):
86
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
87
+
88
+ def facex_load_file_from_url(**kwargs):
89
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
90
+
91
+ def facex_load_file_from_url2(**kwargs):
92
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
93
+
94
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
95
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
96
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
97
+ user_path = dirname
98
+ have_gfpgan = True
99
+ gfpgan_constructor = GFPGANer
100
+
101
+ class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
102
+ def name(self):
103
+ return "GFPGAN"
104
+
105
+ def restore(self, np_image):
106
+ return gfpgan_fix_faces(np_image)
107
+
108
+ shared.face_restorers.append(FaceRestorerGFPGAN())
109
+ except Exception:
110
+ errors.report("Error setting up GFPGAN", exc_info=True)
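As a rough usage sketch (illustrative only, not part of the commit): setup_model registers the restorer and patches the weight download paths, after which gfpgan_fix_faces maps an RGB uint8 array to a restored RGB uint8 array. The directory name and the all-black test image below are placeholders, and gfpgan/facexlib must be installed:

    import numpy as np
    from modules import gfpgan_model

    gfpgan_model.setup_model("models/GFPGAN")        # hypothetical models directory; registers FaceRestorerGFPGAN
    rgb = np.zeros((512, 512, 3), dtype=np.uint8)    # stand-in for a real photo
    restored = gfpgan_model.gfpgan_fix_faces(rgb)    # falls back to returning the input if the model cannot be loaded
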
modules/gitpython_hack.py ADDED
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import subprocess
5
+
6
+ import git
7
+
8
+
9
+ class Git(git.Git):
10
+ """
11
+ Git subclassed to never use persistent processes.
12
+ """
13
+
14
+ def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
15
+ raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
16
+
17
+ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
18
+ ret = subprocess.check_output(
19
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
20
+ input=self._prepare_ref(ref),
21
+ cwd=self._working_dir,
22
+ timeout=2,
23
+ )
24
+ return self._parse_object_header(ret)
25
+
26
+ def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
27
+ # Not really streaming, per se; this buffers the entire object in memory.
28
+ # Shouldn't be a problem for our use case, since we're only using this for
29
+ # object headers (commit objects).
30
+ ret = subprocess.check_output(
31
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
32
+ input=self._prepare_ref(ref),
33
+ cwd=self._working_dir,
34
+ timeout=30,
35
+ )
36
+ bio = io.BytesIO(ret)
37
+ hexsha, typename, size = self._parse_object_header(bio.readline())
38
+ return (hexsha, typename, size, self.CatFileContentStream(size, bio))
39
+
40
+
41
+ class Repo(git.Repo):
42
+ GitCommandWrapperType = Git
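A small illustrative sketch (not part of the commit) of how this Repo subclass is meant to be used; the path is hypothetical. Because GitCommandWrapperType points at the subclass above, object reads go through one-shot cat-file invocations instead of a persistent git process:

    from modules.gitpython_hack import Repo

    repo = Repo("extensions/some-extension")          # hypothetical checkout
    commit = repo.head.commit                         # object data fetched via the overridden cat-file helpers
    print(commit.hexsha, commit.message.splitlines()[0])
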
modules/hashes.py ADDED
@@ -0,0 +1,81 @@
1
+ import hashlib
2
+ import os.path
3
+
4
+ from modules import shared
5
+ import modules.cache
6
+
7
+ dump_cache = modules.cache.dump_cache
8
+ cache = modules.cache.cache
9
+
10
+
11
+ def calculate_sha256(filename):
12
+ hash_sha256 = hashlib.sha256()
13
+ blksize = 1024 * 1024
14
+
15
+ with open(filename, "rb") as f:
16
+ for chunk in iter(lambda: f.read(blksize), b""):
17
+ hash_sha256.update(chunk)
18
+
19
+ return hash_sha256.hexdigest()
20
+
21
+
22
+ def sha256_from_cache(filename, title, use_addnet_hash=False):
23
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
24
+ ondisk_mtime = os.path.getmtime(filename)
25
+
26
+ if title not in hashes:
27
+ return None
28
+
29
+ cached_sha256 = hashes[title].get("sha256", None)
30
+ cached_mtime = hashes[title].get("mtime", 0)
31
+
32
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
33
+ return None
34
+
35
+ return cached_sha256
36
+
37
+
38
+ def sha256(filename, title, use_addnet_hash=False):
39
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
40
+
41
+ sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
42
+ if sha256_value is not None:
43
+ return sha256_value
44
+
45
+ if shared.cmd_opts.no_hashing:
46
+ return None
47
+
48
+ print(f"Calculating sha256 for {filename}: ", end='')
49
+ if use_addnet_hash:
50
+ with open(filename, "rb") as file:
51
+ sha256_value = addnet_hash_safetensors(file)
52
+ else:
53
+ sha256_value = calculate_sha256(filename)
54
+ print(f"{sha256_value}")
55
+
56
+ hashes[title] = {
57
+ "mtime": os.path.getmtime(filename),
58
+ "sha256": sha256_value,
59
+ }
60
+
61
+ dump_cache()
62
+
63
+ return sha256_value
64
+
65
+
66
+ def addnet_hash_safetensors(b):
67
+ """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
68
+ hash_sha256 = hashlib.sha256()
69
+ blksize = 1024 * 1024
70
+
71
+ b.seek(0)
72
+ header = b.read(8)
73
+ n = int.from_bytes(header, "little")
74
+
75
+ offset = n + 8
76
+ b.seek(offset)
77
+ for chunk in iter(lambda: b.read(blksize), b""):
78
+ hash_sha256.update(chunk)
79
+
80
+ return hash_sha256.hexdigest()
81
+
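For reference, a minimal sketch (illustrative only, not part of the commit) of the two hashing paths above; the filename is a placeholder and the webui environment is assumed to be importable. The addnet variant skips the safetensors JSON header, so the two digests generally differ for the same file:

    from modules import hashes

    full_digest = hashes.calculate_sha256("models/example.safetensors")     # hashes the whole file
    with open("models/example.safetensors", "rb") as f:
        addnet_digest = hashes.addnet_hash_safetensors(f)                   # hashes only the data after the header
    print(full_digest[:10], addnet_digest[:12])
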
modules/hypernetworks/__pycache__/hypernetwork.cpython-310.pyc ADDED
Binary file (21.5 kB)
modules/hypernetworks/__pycache__/ui.cpython-310.pyc ADDED
Binary file (1.69 kB)
modules/hypernetworks/hypernetwork.py ADDED
@@ -0,0 +1,783 @@
1
+ import datetime
2
+ import glob
3
+ import html
4
+ import os
5
+ import inspect
6
+ from contextlib import closing
7
+
8
+ import modules.textual_inversion.dataset
9
+ import torch
10
+ import tqdm
11
+ from einops import rearrange, repeat
12
+ from ldm.util import default
13
+ from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
14
+ from modules.textual_inversion import textual_inversion, logging
15
+ from modules.textual_inversion.learn_schedule import LearnRateScheduler
16
+ from torch import einsum
17
+ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
18
+
19
+ from collections import deque
20
+ from statistics import stdev, mean
21
+
22
+
23
+ optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
24
+
25
+ class HypernetworkModule(torch.nn.Module):
26
+ activation_dict = {
27
+ "linear": torch.nn.Identity,
28
+ "relu": torch.nn.ReLU,
29
+ "leakyrelu": torch.nn.LeakyReLU,
30
+ "elu": torch.nn.ELU,
31
+ "swish": torch.nn.Hardswish,
32
+ "tanh": torch.nn.Tanh,
33
+ "sigmoid": torch.nn.Sigmoid,
34
+ }
35
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
36
+
37
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
38
+ add_layer_norm=False, activate_output=False, dropout_structure=None):
39
+ super().__init__()
40
+
41
+ self.multiplier = 1.0
42
+
43
+ assert layer_structure is not None, "layer_structure must not be None"
44
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
45
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
46
+
47
+ linears = []
48
+ for i in range(len(layer_structure) - 1):
49
+
50
+ # Add a fully-connected layer
51
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
52
+
53
+ # Add an activation func except last layer
54
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
55
+ pass
56
+ elif activation_func in self.activation_dict:
57
+ linears.append(self.activation_dict[activation_func]())
58
+ else:
59
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
60
+
61
+ # Add layer normalization
62
+ if add_layer_norm:
63
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
64
+
65
+ # Everything should be now parsed into dropout structure, and applied here.
66
+ # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
67
+ if dropout_structure is not None and dropout_structure[i+1] > 0:
68
+ assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
69
+ linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
70
+ # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
71
+
72
+ self.linear = torch.nn.Sequential(*linears)
73
+
74
+ if state_dict is not None:
75
+ self.fix_old_state_dict(state_dict)
76
+ self.load_state_dict(state_dict)
77
+ else:
78
+ for layer in self.linear:
79
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
80
+ w, b = layer.weight.data, layer.bias.data
81
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
82
+ normal_(w, mean=0.0, std=0.01)
83
+ normal_(b, mean=0.0, std=0)
84
+ elif weight_init == 'XavierUniform':
85
+ xavier_uniform_(w)
86
+ zeros_(b)
87
+ elif weight_init == 'XavierNormal':
88
+ xavier_normal_(w)
89
+ zeros_(b)
90
+ elif weight_init == 'KaimingUniform':
91
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
92
+ zeros_(b)
93
+ elif weight_init == 'KaimingNormal':
94
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
95
+ zeros_(b)
96
+ else:
97
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
98
+ self.to(devices.device)
99
+
100
+ def fix_old_state_dict(self, state_dict):
101
+ changes = {
102
+ 'linear1.bias': 'linear.0.bias',
103
+ 'linear1.weight': 'linear.0.weight',
104
+ 'linear2.bias': 'linear.1.bias',
105
+ 'linear2.weight': 'linear.1.weight',
106
+ }
107
+
108
+ for fr, to in changes.items():
109
+ x = state_dict.get(fr, None)
110
+ if x is None:
111
+ continue
112
+
113
+ del state_dict[fr]
114
+ state_dict[to] = x
115
+
116
+ def forward(self, x):
117
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
118
+
119
+ def trainables(self):
120
+ layer_structure = []
121
+ for layer in self.linear:
122
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
123
+ layer_structure += [layer.weight, layer.bias]
124
+ return layer_structure
125
+
126
+
127
+ #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
128
+ def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
129
+ if layer_structure is None:
130
+ layer_structure = [1, 2, 1]
131
+ if not use_dropout:
132
+ return [0] * len(layer_structure)
133
+ dropout_values = [0]
134
+ dropout_values.extend([0.3] * (len(layer_structure) - 3))
135
+ if last_layer_dropout:
136
+ dropout_values.append(0.3)
137
+ else:
138
+ dropout_values.append(0)
139
+ dropout_values.append(0)
140
+ return dropout_values
141
+
142
+
143
+ class Hypernetwork:
144
+ filename = None
145
+ name = None
146
+
147
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
148
+ self.filename = None
149
+ self.name = name
150
+ self.layers = {}
151
+ self.step = 0
152
+ self.sd_checkpoint = None
153
+ self.sd_checkpoint_name = None
154
+ self.layer_structure = layer_structure
155
+ self.activation_func = activation_func
156
+ self.weight_init = weight_init
157
+ self.add_layer_norm = add_layer_norm
158
+ self.use_dropout = use_dropout
159
+ self.activate_output = activate_output
160
+ self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
161
+ self.dropout_structure = kwargs.get('dropout_structure', None)
162
+ if self.dropout_structure is None:
163
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
164
+ self.optimizer_name = None
165
+ self.optimizer_state_dict = None
166
+ self.optional_info = None
167
+
168
+ for size in enable_sizes or []:
169
+ self.layers[size] = (
170
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
171
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
172
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
173
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
174
+ )
175
+ self.eval()
176
+
177
+ def weights(self):
178
+ res = []
179
+ for layers in self.layers.values():
180
+ for layer in layers:
181
+ res += layer.parameters()
182
+ return res
183
+
184
+ def train(self, mode=True):
185
+ for layers in self.layers.values():
186
+ for layer in layers:
187
+ layer.train(mode=mode)
188
+ for param in layer.parameters():
189
+ param.requires_grad = mode
190
+
191
+ def to(self, device):
192
+ for layers in self.layers.values():
193
+ for layer in layers:
194
+ layer.to(device)
195
+
196
+ return self
197
+
198
+ def set_multiplier(self, multiplier):
199
+ for layers in self.layers.values():
200
+ for layer in layers:
201
+ layer.multiplier = multiplier
202
+
203
+ return self
204
+
205
+ def eval(self):
206
+ for layers in self.layers.values():
207
+ for layer in layers:
208
+ layer.eval()
209
+ for param in layer.parameters():
210
+ param.requires_grad = False
211
+
212
+ def save(self, filename):
213
+ state_dict = {}
214
+ optimizer_saved_dict = {}
215
+
216
+ for k, v in self.layers.items():
217
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
218
+
219
+ state_dict['step'] = self.step
220
+ state_dict['name'] = self.name
221
+ state_dict['layer_structure'] = self.layer_structure
222
+ state_dict['activation_func'] = self.activation_func
223
+ state_dict['is_layer_norm'] = self.add_layer_norm
224
+ state_dict['weight_initialization'] = self.weight_init
225
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
226
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
227
+ state_dict['activate_output'] = self.activate_output
228
+ state_dict['use_dropout'] = self.use_dropout
229
+ state_dict['dropout_structure'] = self.dropout_structure
230
+ state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
231
+ state_dict['optional_info'] = self.optional_info if self.optional_info else None
232
+
233
+ if self.optimizer_name is not None:
234
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
235
+
236
+ torch.save(state_dict, filename)
237
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
238
+ optimizer_saved_dict['hash'] = self.shorthash()
239
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
240
+ torch.save(optimizer_saved_dict, filename + '.optim')
241
+
242
+ def load(self, filename):
243
+ self.filename = filename
244
+ if self.name is None:
245
+ self.name = os.path.splitext(os.path.basename(filename))[0]
246
+
247
+ state_dict = torch.load(filename, map_location='cpu')
248
+
249
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
250
+ self.optional_info = state_dict.get('optional_info', None)
251
+ self.activation_func = state_dict.get('activation_func', None)
252
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
253
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
254
+ self.dropout_structure = state_dict.get('dropout_structure', None)
255
+ self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
256
+ self.activate_output = state_dict.get('activate_output', True)
257
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
258
+ # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
259
+ if self.dropout_structure is None:
260
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
261
+
262
+ if shared.opts.print_hypernet_extra:
263
+ if self.optional_info is not None:
264
+ print(f" INFO:\n {self.optional_info}\n")
265
+
266
+ print(f" Layer structure: {self.layer_structure}")
267
+ print(f" Activation function: {self.activation_func}")
268
+ print(f" Weight initialization: {self.weight_init}")
269
+ print(f" Layer norm: {self.add_layer_norm}")
270
+ print(f" Dropout usage: {self.use_dropout}" )
271
+ print(f" Activate last layer: {self.activate_output}")
272
+ print(f" Dropout structure: {self.dropout_structure}")
273
+
274
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
275
+
276
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
277
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
278
+ else:
279
+ self.optimizer_state_dict = None
280
+ if self.optimizer_state_dict:
281
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
282
+ if shared.opts.print_hypernet_extra:
283
+ print("Loaded existing optimizer from checkpoint")
284
+ print(f"Optimizer name is {self.optimizer_name}")
285
+ else:
286
+ self.optimizer_name = "AdamW"
287
+ if shared.opts.print_hypernet_extra:
288
+ print("No saved optimizer exists in checkpoint")
289
+
290
+ for size, sd in state_dict.items():
291
+ if type(size) == int:
292
+ self.layers[size] = (
293
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
294
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
295
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
296
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
297
+ )
298
+
299
+ self.name = state_dict.get('name', self.name)
300
+ self.step = state_dict.get('step', 0)
301
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
302
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
303
+ self.eval()
304
+
305
+ def shorthash(self):
306
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
307
+
308
+ return sha256[0:10] if sha256 else None
309
+
310
+
311
+ def list_hypernetworks(path):
312
+ res = {}
313
+ for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
314
+ name = os.path.splitext(os.path.basename(filename))[0]
315
+ # Prevent a hypothetical "None.pt" from being listed.
316
+ if name != "None":
317
+ res[name] = filename
318
+ return res
319
+
320
+
321
+ def load_hypernetwork(name):
322
+ path = shared.hypernetworks.get(name, None)
323
+
324
+ if path is None:
325
+ return None
326
+
327
+ try:
328
+ hypernetwork = Hypernetwork()
329
+ hypernetwork.load(path)
330
+ return hypernetwork
331
+ except Exception:
332
+ errors.report(f"Error loading hypernetwork {path}", exc_info=True)
333
+ return None
334
+
335
+
336
+ def load_hypernetworks(names, multipliers=None):
337
+ already_loaded = {}
338
+
339
+ for hypernetwork in shared.loaded_hypernetworks:
340
+ if hypernetwork.name in names:
341
+ already_loaded[hypernetwork.name] = hypernetwork
342
+
343
+ shared.loaded_hypernetworks.clear()
344
+
345
+ for i, name in enumerate(names):
346
+ hypernetwork = already_loaded.get(name, None)
347
+ if hypernetwork is None:
348
+ hypernetwork = load_hypernetwork(name)
349
+
350
+ if hypernetwork is None:
351
+ continue
352
+
353
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
354
+ shared.loaded_hypernetworks.append(hypernetwork)
355
+
356
+
357
+ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
358
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
359
+
360
+ if hypernetwork_layers is None:
361
+ return context_k, context_v
362
+
363
+ if layer is not None:
364
+ layer.hyper_k = hypernetwork_layers[0]
365
+ layer.hyper_v = hypernetwork_layers[1]
366
+
367
+ context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
368
+ context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
369
+ return context_k, context_v
370
+
371
+
372
+ def apply_hypernetworks(hypernetworks, context, layer=None):
373
+ context_k = context
374
+ context_v = context
375
+ for hypernetwork in hypernetworks:
376
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
377
+
378
+ return context_k, context_v
379
+
380
+
381
+ def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
382
+ h = self.heads
383
+
384
+ q = self.to_q(x)
385
+ context = default(context, x)
386
+
387
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
388
+ k = self.to_k(context_k)
389
+ v = self.to_v(context_v)
390
+
391
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
392
+
393
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
394
+
395
+ if mask is not None:
396
+ mask = rearrange(mask, 'b ... -> b (...)')
397
+ max_neg_value = -torch.finfo(sim.dtype).max
398
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
399
+ sim.masked_fill_(~mask, max_neg_value)
400
+
401
+ # attention, what we cannot get enough of
402
+ attn = sim.softmax(dim=-1)
403
+
404
+ out = einsum('b i j, b j d -> b i d', attn, v)
405
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
406
+ return self.to_out(out)
407
+
408
+
409
+ def stack_conds(conds):
410
+ if len(conds) == 1:
411
+ return torch.stack(conds)
412
+
413
+ # same as in reconstruct_multicond_batch
414
+ token_count = max([x.shape[0] for x in conds])
415
+ for i in range(len(conds)):
416
+ if conds[i].shape[0] != token_count:
417
+ last_vector = conds[i][-1:]
418
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
419
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
420
+
421
+ return torch.stack(conds)
422
+
423
+
424
+ def statistics(data):
425
+ if len(data) < 2:
426
+ std = 0
427
+ else:
428
+ std = stdev(data)
429
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
430
+ recent_data = data[-32:]
431
+ if len(recent_data) < 2:
432
+ std = 0
433
+ else:
434
+ std = stdev(recent_data)
435
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
436
+ return total_information, recent_information
437
+
438
+
439
+ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
440
+ # Remove illegal characters from name.
441
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
442
+ assert name, "Name cannot be empty!"
443
+
444
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
445
+ if not overwrite_old:
446
+ assert not os.path.exists(fn), f"file {fn} already exists"
447
+
448
+ if type(layer_structure) == str:
449
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
450
+
451
+ if use_dropout and dropout_structure and type(dropout_structure) == str:
452
+ dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
453
+ else:
454
+ dropout_structure = [0] * len(layer_structure)
455
+
456
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
457
+ name=name,
458
+ enable_sizes=[int(x) for x in enable_sizes],
459
+ layer_structure=layer_structure,
460
+ activation_func=activation_func,
461
+ weight_init=weight_init,
462
+ add_layer_norm=add_layer_norm,
463
+ use_dropout=use_dropout,
464
+ dropout_structure=dropout_structure
465
+ )
466
+ hypernet.save(fn)
467
+
468
+ shared.reload_hypernetworks()
469
+
470
+
471
+ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
472
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
473
+ from modules import images
474
+
475
+ save_hypernetwork_every = save_hypernetwork_every or 0
476
+ create_image_every = create_image_every or 0
477
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
478
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
479
+ template_file = template_file.path
480
+
481
+ path = shared.hypernetworks.get(hypernetwork_name, None)
482
+ hypernetwork = Hypernetwork()
483
+ hypernetwork.load(path)
484
+ shared.loaded_hypernetworks = [hypernetwork]
485
+
486
+ shared.state.job = "train-hypernetwork"
487
+ shared.state.textinfo = "Initializing hypernetwork training..."
488
+ shared.state.job_count = steps
489
+
490
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
491
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
492
+
493
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
494
+ unload = shared.opts.unload_models_when_training
495
+
496
+ if save_hypernetwork_every > 0:
497
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
498
+ os.makedirs(hypernetwork_dir, exist_ok=True)
499
+ else:
500
+ hypernetwork_dir = None
501
+
502
+ if create_image_every > 0:
503
+ images_dir = os.path.join(log_directory, "images")
504
+ os.makedirs(images_dir, exist_ok=True)
505
+ else:
506
+ images_dir = None
507
+
508
+ checkpoint = sd_models.select_checkpoint()
509
+
510
+ initial_step = hypernetwork.step or 0
511
+ if initial_step >= steps:
512
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
513
+ return hypernetwork, filename
514
+
515
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
516
+
517
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
518
+ if clip_grad:
519
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
520
+
521
+ if shared.opts.training_enable_tensorboard:
522
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
523
+
524
+ # dataset loading may take a while, so input validations and early returns should be done before this
525
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
526
+
527
+ pin_memory = shared.opts.pin_memory
528
+
529
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
530
+
531
+ if shared.opts.save_training_settings_to_txt:
532
+ saved_params = dict(
533
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
534
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
535
+ )
536
+ logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
537
+
538
+ latent_sampling_method = ds.latent_sampling_method
539
+
540
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
541
+
542
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
543
+
544
+ if unload:
545
+ shared.parallel_processing_allowed = False
546
+ shared.sd_model.cond_stage_model.to(devices.cpu)
547
+ shared.sd_model.first_stage_model.to(devices.cpu)
548
+
549
+ weights = hypernetwork.weights()
550
+ hypernetwork.train()
551
+
552
+ # Here we use optimizer from saved HN, or we can specify as UI option.
553
+ if hypernetwork.optimizer_name in optimizer_dict:
554
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
555
+ optimizer_name = hypernetwork.optimizer_name
556
+ else:
557
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
558
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
559
+ optimizer_name = 'AdamW'
560
+
561
+ if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
562
+ try:
563
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
564
+ except RuntimeError as e:
565
+ print("Cannot resume from saved optimizer!")
566
+ print(e)
567
+
568
+ scaler = torch.cuda.amp.GradScaler()
569
+
570
+ batch_size = ds.batch_size
571
+ gradient_step = ds.gradient_step
572
+ # n steps = batch_size * gradient_step * n image processed
573
+ steps_per_epoch = len(ds) // batch_size // gradient_step
574
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
575
+ loss_step = 0
576
+ _loss_step = 0 #internal
577
+ # size = len(ds.indexes)
578
+ # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
579
+ loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
580
+ # losses = torch.zeros((size,))
581
+ # previous_mean_losses = [0]
582
+ # previous_mean_loss = 0
583
+ # print("Mean loss of {} elements".format(size))
584
+
585
+ steps_without_grad = 0
586
+
587
+ last_saved_file = "<none>"
588
+ last_saved_image = "<none>"
589
+ forced_filename = "<none>"
590
+
591
+ pbar = tqdm.tqdm(total=steps - initial_step)
592
+ try:
593
+ sd_hijack_checkpoint.add()
594
+
595
+ for _ in range((steps-initial_step) * gradient_step):
596
+ if scheduler.finished:
597
+ break
598
+ if shared.state.interrupted:
599
+ break
600
+ for j, batch in enumerate(dl):
601
+ # works as a drop_last=True for gradient accumulation
602
+ if j == max_steps_per_epoch:
603
+ break
604
+ scheduler.apply(optimizer, hypernetwork.step)
605
+ if scheduler.finished:
606
+ break
607
+ if shared.state.interrupted:
608
+ break
609
+
610
+ if clip_grad:
611
+ clip_grad_sched.step(hypernetwork.step)
612
+
613
+ with devices.autocast():
614
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
615
+ if use_weight:
616
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
617
+ if tag_drop_out != 0 or shuffle_tags:
618
+ shared.sd_model.cond_stage_model.to(devices.device)
619
+ c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
620
+ shared.sd_model.cond_stage_model.to(devices.cpu)
621
+ else:
622
+ c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
623
+ if use_weight:
624
+ loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
625
+ del w
626
+ else:
627
+ loss = shared.sd_model.forward(x, c)[0] / gradient_step
628
+ del x
629
+ del c
630
+
631
+ _loss_step += loss.item()
632
+ scaler.scale(loss).backward()
633
+
634
+ # go back until we reach gradient accumulation steps
635
+ if (j + 1) % gradient_step != 0:
636
+ continue
637
+ loss_logging.append(_loss_step)
638
+ if clip_grad:
639
+ clip_grad(weights, clip_grad_sched.learn_rate)
640
+
641
+ scaler.step(optimizer)
642
+ scaler.update()
643
+ hypernetwork.step += 1
644
+ pbar.update()
645
+ optimizer.zero_grad(set_to_none=True)
646
+ loss_step = _loss_step
647
+ _loss_step = 0
648
+
649
+ steps_done = hypernetwork.step + 1
650
+
651
+ epoch_num = hypernetwork.step // steps_per_epoch
652
+ epoch_step = hypernetwork.step % steps_per_epoch
653
+
654
+ description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
655
+ pbar.set_description(description)
656
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
657
+ # Before saving, change name to match current checkpoint.
658
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
659
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
660
+ hypernetwork.optimizer_name = optimizer_name
661
+ if shared.opts.save_optimizer_state:
662
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
663
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
664
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
665
+
666
+
667
+
668
+ if shared.opts.training_enable_tensorboard:
669
+ epoch_num = hypernetwork.step // len(ds)
670
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
671
+ mean_loss = sum(loss_logging) / len(loss_logging)
672
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
673
+
674
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
675
+ "loss": f"{loss_step:.7f}",
676
+ "learn_rate": scheduler.learn_rate
677
+ })
678
+
679
+ if images_dir is not None and steps_done % create_image_every == 0:
680
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
681
+ last_saved_image = os.path.join(images_dir, forced_filename)
682
+ hypernetwork.eval()
683
+ rng_state = torch.get_rng_state()
684
+ cuda_rng_state = None
685
+ if torch.cuda.is_available():
686
+ cuda_rng_state = torch.cuda.get_rng_state_all()
687
+ shared.sd_model.cond_stage_model.to(devices.device)
688
+ shared.sd_model.first_stage_model.to(devices.device)
689
+
690
+ p = processing.StableDiffusionProcessingTxt2Img(
691
+ sd_model=shared.sd_model,
692
+ do_not_save_grid=True,
693
+ do_not_save_samples=True,
694
+ )
695
+
696
+ p.disable_extra_networks = True
697
+
698
+ if preview_from_txt2img:
699
+ p.prompt = preview_prompt
700
+ p.negative_prompt = preview_negative_prompt
701
+ p.steps = preview_steps
702
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
703
+ p.cfg_scale = preview_cfg_scale
704
+ p.seed = preview_seed
705
+ p.width = preview_width
706
+ p.height = preview_height
707
+ else:
708
+ p.prompt = batch.cond_text[0]
709
+ p.steps = 20
710
+ p.width = training_width
711
+ p.height = training_height
712
+
713
+ preview_text = p.prompt
714
+
715
+ with closing(p):
716
+ processed = processing.process_images(p)
717
+ image = processed.images[0] if len(processed.images) > 0 else None
718
+
719
+ if unload:
720
+ shared.sd_model.cond_stage_model.to(devices.cpu)
721
+ shared.sd_model.first_stage_model.to(devices.cpu)
722
+ torch.set_rng_state(rng_state)
723
+ if torch.cuda.is_available():
724
+ torch.cuda.set_rng_state_all(cuda_rng_state)
725
+ hypernetwork.train()
726
+ if image is not None:
727
+ shared.state.assign_current_image(image)
728
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
729
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
730
+ f"Validation at epoch {epoch_num}", image,
731
+ hypernetwork.step)
732
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
733
+ last_saved_image += f", prompt: {preview_text}"
734
+
735
+ shared.state.job_no = hypernetwork.step
736
+
737
+ shared.state.textinfo = f"""
738
+ <p>
739
+ Loss: {loss_step:.7f}<br/>
740
+ Step: {steps_done}<br/>
741
+ Last prompt: {html.escape(batch.cond_text[0])}<br/>
742
+ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
743
+ Last saved image: {html.escape(last_saved_image)}<br/>
744
+ </p>
745
+ """
746
+ except Exception:
747
+ errors.report("Exception in training hypernetwork", exc_info=True)
748
+ finally:
749
+ pbar.leave = False
750
+ pbar.close()
751
+ hypernetwork.eval()
752
+ sd_hijack_checkpoint.remove()
753
+
754
+
755
+
756
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
757
+ hypernetwork.optimizer_name = optimizer_name
758
+ if shared.opts.save_optimizer_state:
759
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
760
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
761
+
762
+ del optimizer
763
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
764
+ shared.sd_model.cond_stage_model.to(devices.device)
765
+ shared.sd_model.first_stage_model.to(devices.device)
766
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
767
+
768
+ return hypernetwork, filename
769
+
770
+ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
771
+ old_hypernetwork_name = hypernetwork.name
772
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
773
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
774
+ try:
775
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
776
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
777
+ hypernetwork.name = hypernetwork_name
778
+ hypernetwork.save(filename)
779
+ except:
780
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
781
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
782
+ hypernetwork.name = old_hypernetwork_name
783
+ raise
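To make the dropout handling above concrete, a small worked example (illustrative only, not part of the commit) of parse_dropout_structure; the values match the in-code comment about [1, 2, 1] and [1, 2, 2, 1]:

    from modules.hypernetworks.hypernetwork import parse_dropout_structure

    parse_dropout_structure([1, 2, 1], use_dropout=True, last_layer_dropout=False)     # -> [0, 0, 0]
    parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=False)  # -> [0, 0.3, 0, 0]
    parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=True)   # -> [0, 0.3, 0.3, 0]
    parse_dropout_structure([1, 2, 2, 1], use_dropout=False, last_layer_dropout=True)  # -> [0, 0, 0, 0]
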
modules/hypernetworks/ui.py ADDED
@@ -0,0 +1,38 @@
1
+ import html
2
+
3
+ import gradio as gr
4
+ import modules.hypernetworks.hypernetwork
5
+ from modules import devices, sd_hijack, shared
6
+
7
+ not_available = ["hardswish", "multiheadattention"]
8
+ keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
9
+
10
+
11
+ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
12
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
13
+
14
+ return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
15
+
16
+
17
+ def train_hypernetwork(*args):
18
+ shared.loaded_hypernetworks = []
19
+
20
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
21
+
22
+ try:
23
+ sd_hijack.undo_optimizations()
24
+
25
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
26
+
27
+ res = f"""
28
+ Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
29
+ Hypernetwork saved to {html.escape(filename)}
30
+ """
31
+ return res, ""
32
+ except Exception:
33
+ raise
34
+ finally:
35
+ shared.sd_model.cond_stage_model.to(devices.device)
36
+ shared.sd_model.first_stage_model.to(devices.device)
37
+ sd_hijack.apply_optimizations()
38
+
modules/images.py ADDED
@@ -0,0 +1,758 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+
5
+ import pytz
6
+ import io
7
+ import math
8
+ import os
9
+ from collections import namedtuple
10
+ import re
11
+
12
+ import numpy as np
13
+ import piexif
14
+ import piexif.helper
15
+ from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
16
+ import string
17
+ import json
18
+ import hashlib
19
+
20
+ from modules import sd_samplers, shared, script_callbacks, errors
21
+ from modules.paths_internal import roboto_ttf_file
22
+ from modules.shared import opts
23
+
24
+ import modules.sd_vae as sd_vae
25
+
26
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
27
+
28
+
29
+ def get_font(fontsize: int):
30
+ try:
31
+ return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
32
+ except Exception:
33
+ return ImageFont.truetype(roboto_ttf_file, fontsize)
34
+
35
+
36
+ def image_grid(imgs, batch_size=1, rows=None):
37
+ if rows is None:
38
+ if opts.n_rows > 0:
39
+ rows = opts.n_rows
40
+ elif opts.n_rows == 0:
41
+ rows = batch_size
42
+ elif opts.grid_prevent_empty_spots:
43
+ rows = math.floor(math.sqrt(len(imgs)))
44
+ while len(imgs) % rows != 0:
45
+ rows -= 1
46
+ else:
47
+ rows = math.sqrt(len(imgs))
48
+ rows = round(rows)
49
+ if rows > len(imgs):
50
+ rows = len(imgs)
51
+
52
+ cols = math.ceil(len(imgs) / rows)
53
+
54
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
55
+ script_callbacks.image_grid_callback(params)
56
+
57
+ w, h = imgs[0].size
58
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
59
+
60
+ for i, img in enumerate(params.imgs):
61
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
62
+
63
+ return grid
64
+
65
+
66
+ Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
67
+
68
+
69
+ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
70
+ w = image.width
71
+ h = image.height
72
+
73
+ non_overlap_width = tile_w - overlap
74
+ non_overlap_height = tile_h - overlap
75
+
76
+ cols = math.ceil((w - overlap) / non_overlap_width)
77
+ rows = math.ceil((h - overlap) / non_overlap_height)
78
+
79
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
80
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
81
+
82
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
83
+ for row in range(rows):
84
+ row_images = []
85
+
86
+ y = int(row * dy)
87
+
88
+ if y + tile_h >= h:
89
+ y = h - tile_h
90
+
91
+ for col in range(cols):
92
+ x = int(col * dx)
93
+
94
+ if x + tile_w >= w:
95
+ x = w - tile_w
96
+
97
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
98
+
99
+ row_images.append([x, tile_w, tile])
100
+
101
+ grid.tiles.append([y, tile_h, row_images])
102
+
103
+ return grid
104
+
105
+
106
+ def combine_grid(grid):
107
+ def make_mask_image(r):
108
+ r = r * 255 / grid.overlap
109
+ r = r.astype(np.uint8)
110
+ return Image.fromarray(r, 'L')
111
+
112
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
113
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
114
+
115
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
116
+ for y, h, row in grid.tiles:
117
+ combined_row = Image.new("RGB", (grid.image_w, h))
118
+ for x, w, tile in row:
119
+ if x == 0:
120
+ combined_row.paste(tile, (0, 0))
121
+ continue
122
+
123
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
124
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
125
+
126
+ if y == 0:
127
+ combined_image.paste(combined_row, (0, 0))
128
+ continue
129
+
130
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
131
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
132
+
133
+ return combined_image
134
+
135
+
136
+ class GridAnnotation:
137
+ def __init__(self, text='', is_active=True):
138
+ self.text = text
139
+ self.is_active = is_active
140
+ self.size = None
141
+
142
+
143
+ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
144
+
145
+ color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
146
+ color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
147
+ color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
148
+
149
+ def wrap(drawing, text, font, line_length):
150
+ lines = ['']
151
+ for word in text.split():
152
+ line = f'{lines[-1]} {word}'.strip()
153
+ if drawing.textlength(line, font=font) <= line_length:
154
+ lines[-1] = line
155
+ else:
156
+ lines.append(word)
157
+ return lines
158
+
159
+ def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
160
+ for line in lines:
161
+ fnt = initial_fnt
162
+ fontsize = initial_fontsize
163
+ while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
164
+ fontsize -= 1
165
+ fnt = get_font(fontsize)
166
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
167
+
168
+ if not line.is_active:
169
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
170
+
171
+ draw_y += line.size[1] + line_spacing
172
+
173
+ fontsize = (width + height) // 25
174
+ line_spacing = fontsize // 2
175
+
176
+ fnt = get_font(fontsize)
177
+
178
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
179
+
180
+ cols = im.width // width
181
+ rows = im.height // height
182
+
183
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
184
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
185
+
186
+ calc_img = Image.new("RGB", (1, 1), color_background)
187
+ calc_d = ImageDraw.Draw(calc_img)
188
+
189
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
190
+ items = [] + texts
191
+ texts.clear()
192
+
193
+ for line in items:
194
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
195
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
196
+
197
+ for line in texts:
198
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
199
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
200
+ line.allowed_width = allowed_width
201
+
202
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
203
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
204
+
205
+ pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
206
+
207
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
208
+
209
+ for row in range(rows):
210
+ for col in range(cols):
211
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
212
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
213
+
214
+ d = ImageDraw.Draw(result)
215
+
216
+ for col in range(cols):
217
+ x = pad_left + (width + margin) * col + width / 2
218
+ y = pad_top / 2 - hor_text_heights[col] / 2
219
+
220
+ draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
221
+
222
+ for row in range(rows):
223
+ x = pad_left / 2
224
+ y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
225
+
226
+ draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
227
+
228
+ return result
229
+
230
+
231
+ def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
232
+ prompts = all_prompts[1:]
233
+ boundary = math.ceil(len(prompts) / 2)
234
+
235
+ prompts_horiz = prompts[:boundary]
236
+ prompts_vert = prompts[boundary:]
237
+
238
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
239
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
240
+
241
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
242
+
243
+
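The bit-mask indexing above can be illustrated with a small standalone snippet (illustrative only; the prompt fragments are made up): for cell index pos, prompt i is treated as active when bit i of pos is set, which is how draw_prompt_matrix enumerates every prompt combination.

prompts = ["A", "B"]                                  # hypothetical prompt fragments
for pos in range(1 << len(prompts)):
    active = [p for i, p in enumerate(prompts) if pos & (1 << i)]
    print(pos, active or ["none"])                    # 0 -> none, 1 -> A, 2 -> B, 3 -> A and B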
244
+ def resize_image(resize_mode, im, width, height, upscaler_name=None):
245
+ """
246
+ Resizes an image with the specified resize_mode, width, and height.
247
+
248
+ Args:
249
+ resize_mode: The mode to use when resizing the image.
250
+ 0: Resize the image to the specified width and height.
251
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
252
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
253
+ im: The image to resize.
254
+ width: The width to resize the image to.
255
+ height: The height to resize the image to.
256
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
257
+ """
258
+
259
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
260
+
261
+ def resize(im, w, h):
262
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
263
+ return im.resize((w, h), resample=LANCZOS)
264
+
265
+ scale = max(w / im.width, h / im.height)
266
+
267
+ if scale > 1.0:
268
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
269
+ if len(upscalers) == 0:
270
+ upscaler = shared.sd_upscalers[0]
271
+ print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
272
+ else:
273
+ upscaler = upscalers[0]
274
+
275
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
276
+
277
+ if im.width != w or im.height != h:
278
+ im = im.resize((w, h), resample=LANCZOS)
279
+
280
+ return im
281
+
282
+ if resize_mode == 0:
283
+ res = resize(im, width, height)
284
+
285
+ elif resize_mode == 1:
286
+ ratio = width / height
287
+ src_ratio = im.width / im.height
288
+
289
+ src_w = width if ratio > src_ratio else im.width * height // im.height
290
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
291
+
292
+ resized = resize(im, src_w, src_h)
293
+ res = Image.new("RGB", (width, height))
294
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
295
+
296
+ else:
297
+ ratio = width / height
298
+ src_ratio = im.width / im.height
299
+
300
+ src_w = width if ratio < src_ratio else im.width * height // im.height
301
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
302
+
303
+ resized = resize(im, src_w, src_h)
304
+ res = Image.new("RGB", (width, height))
305
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
306
+
307
+ if ratio < src_ratio:
308
+ fill_height = height // 2 - src_h // 2
309
+ if fill_height > 0:
310
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
311
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
312
+ elif ratio > src_ratio:
313
+ fill_width = width // 2 - src_w // 2
314
+ if fill_width > 0:
315
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
316
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
317
+
318
+ return res
319
+
320
+
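A short sketch of how the three resize modes behave (illustrative only; the source image and target size are hypothetical, and the function relies on opts/shared being initialised):

from PIL import Image

src = Image.open("photo.png")                 # hypothetical 640x480 input
stretched = resize_image(0, src, 512, 512)    # mode 0: plain stretch to 512x512
cropped = resize_image(1, src, 512, 512)      # mode 1: fill the target, cropping the excess
padded = resize_image(2, src, 512, 512)       # mode 2: fit inside the target, filling borders from edge pixels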
321
+ invalid_filename_chars = '<>:"/\\|?*\n'
322
+ invalid_filename_prefix = ' '
323
+ invalid_filename_postfix = ' .'
324
+ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
325
+ re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
326
+ re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
327
+ max_filename_part_length = 128
328
+ NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
329
+
330
+
331
+ def sanitize_filename_part(text, replace_spaces=True):
332
+ if text is None:
333
+ return None
334
+
335
+ if replace_spaces:
336
+ text = text.replace(' ', '_')
337
+
338
+ text = text.translate({ord(x): '_' for x in invalid_filename_chars})
339
+ text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
340
+ text = text.rstrip(invalid_filename_postfix)
341
+ return text
342
+
343
+
344
+ class FilenameGenerator:
345
+ def get_vae_filename(self): #get the name of the VAE file.
346
+ if sd_vae.loaded_vae_file is None:
347
+ return "NoneType"
348
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
349
+ split_file_name = file_name.split('.')
350
+ if len(split_file_name) > 1 and split_file_name[0] == '':
351
+ return split_file_name[1] # if the filename starts with ".", use the part after the leading dot
352
+ else:
353
+ return split_file_name[0]
354
+
355
+ replacements = {
356
+ 'seed': lambda self: self.seed if self.seed is not None else '',
357
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
358
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
359
+ 'steps': lambda self: self.p and self.p.steps,
360
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
361
+ 'width': lambda self: self.image.width,
362
+ 'height': lambda self: self.image.height,
363
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
364
+ 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
365
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
366
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
367
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
368
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
369
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
370
+ 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
371
+ 'prompt': lambda self: sanitize_filename_part(self.prompt),
372
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
373
+ 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
374
+ 'prompt_words': lambda self: self.prompt_words(),
375
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
376
+ 'batch_size': lambda self: self.p.batch_size,
377
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
378
+ 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
379
+ 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
380
+ 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
381
+ 'user': lambda self: self.p.user,
382
+ 'vae_filename': lambda self: self.get_vae_filename(),
383
+ 'none': lambda self: '', # Overrides the default so you can get just the sequence number
384
+ }
385
+ default_time_format = '%Y%m%d%H%M%S'
386
+
387
+ def __init__(self, p, seed, prompt, image, zip=False):
388
+ self.p = p
389
+ self.seed = seed
390
+ self.prompt = prompt
391
+ self.image = image
392
+ self.zip = zip
393
+
394
+ def hasprompt(self, *args):
395
+ if self.p is None or self.prompt is None:
396
+ return None
397
+ lower = self.prompt.lower()
398
+ outres = ""
399
+ for arg in args:
400
+ if arg != "":
401
+ division = arg.split("|")
402
+ expected = division[0].lower()
403
+ default = division[1] if len(division) > 1 else ""
404
+ if lower.find(expected) >= 0:
405
+ outres = f'{outres}{expected}'
406
+ else:
407
+ outres = outres if default == "" else f'{outres}{default}'
408
+ return sanitize_filename_part(outres)
409
+
410
+ def prompt_no_style(self):
411
+ if self.p is None or self.prompt is None:
412
+ return None
413
+
414
+ prompt_no_style = self.prompt
415
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
416
+ if style:
417
+ for part in style.split("{prompt}"):
418
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
419
+
420
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
421
+
422
+ return sanitize_filename_part(prompt_no_style, replace_spaces=False)
423
+
424
+ def prompt_words(self):
425
+ words = [x for x in re_nonletters.split(self.prompt or "") if x]
426
+ if len(words) == 0:
427
+ words = ["empty"]
428
+ return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
429
+
430
+ def datetime(self, *args):
431
+ time_datetime = datetime.datetime.now()
432
+
433
+ time_format = args[0] if (args and args[0] != "") else self.default_time_format
434
+ try:
435
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
436
+ except pytz.exceptions.UnknownTimeZoneError:
437
+ time_zone = None
438
+
439
+ time_zone_time = time_datetime.astimezone(time_zone)
440
+ try:
441
+ formatted_time = time_zone_time.strftime(time_format)
442
+ except (ValueError, TypeError):
443
+ formatted_time = time_zone_time.strftime(self.default_time_format)
444
+
445
+ return sanitize_filename_part(formatted_time, replace_spaces=False)
446
+
447
+ def apply(self, x):
448
+ res = ''
449
+
450
+ for m in re_pattern.finditer(x):
451
+ text, pattern = m.groups()
452
+
453
+ if pattern is None:
454
+ res += text
455
+ continue
456
+
457
+ pattern_args = []
458
+ while True:
459
+ m = re_pattern_arg.match(pattern)
460
+ if m is None:
461
+ break
462
+
463
+ pattern, arg = m.groups()
464
+ pattern_args.insert(0, arg)
465
+
466
+ fun = self.replacements.get(pattern.lower())
467
+ if fun is not None:
468
+ try:
469
+ replacement = fun(self, *pattern_args)
470
+ except Exception:
471
+ replacement = None
472
+ errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
473
+
474
+ if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
475
+ continue
476
+ elif replacement is not None:
477
+ res += text + str(replacement)
478
+ continue
479
+
480
+ res += f'{text}[{pattern}]'
481
+
482
+ return res
483
+
484
+
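A sketch of how FilenameGenerator.apply expands [token] patterns (illustrative only; `p` and `img` would come from an actual processing run, and the exact output depends on the configured options). Optional <arguments> such as a datetime format are parsed by re_pattern_arg, and unknown tokens are left in the result unchanged.

namegen = FilenameGenerator(p, seed=12345, prompt="a castle on a hill", image=img)
name = namegen.apply("[seed]-[prompt_words]-[datetime<%Y%m%d>]")
# e.g. "12345-a castle on a hill-20230801", depending on the current date and options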
485
+ def get_next_sequence_number(path, basename):
486
+ """
487
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
488
+
489
+ The sequence starts at 0.
490
+ """
491
+ result = -1
492
+ if basename != '':
493
+ basename = f"{basename}-"
494
+
495
+ prefix_length = len(basename)
496
+ for p in os.listdir(path):
497
+ if p.startswith(basename):
498
+ parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
499
+ try:
500
+ result = max(int(parts[0]), result)
501
+ except ValueError:
502
+ pass
503
+
504
+ return result + 1
505
+
506
+
507
+ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
508
+ """
509
+ Saves image to filename, including geninfo as text information for generation info.
510
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
511
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
512
+ """
513
+
514
+ if extension is None:
515
+ extension = os.path.splitext(filename)[1]
516
+
517
+ image_format = Image.registered_extensions()[extension]
518
+
519
+ if extension.lower() == '.png':
520
+ existing_pnginfo = existing_pnginfo or {}
521
+ if opts.enable_pnginfo:
522
+ existing_pnginfo[pnginfo_section_name] = geninfo
523
+
524
+ if opts.enable_pnginfo:
525
+ pnginfo_data = PngImagePlugin.PngInfo()
526
+ for k, v in (existing_pnginfo or {}).items():
527
+ pnginfo_data.add_text(k, str(v))
528
+ else:
529
+ pnginfo_data = None
530
+
531
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
532
+
533
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
534
+ if image.mode == 'RGBA':
535
+ image = image.convert("RGB")
536
+ elif image.mode == 'I;16':
537
+ image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
538
+
539
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
540
+
541
+ if opts.enable_pnginfo and geninfo is not None:
542
+ exif_bytes = piexif.dump({
543
+ "Exif": {
544
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
545
+ },
546
+ })
547
+
548
+ piexif.insert(exif_bytes, filename)
549
+ else:
550
+ image.save(filename, format=image_format, quality=opts.jpeg_quality)
551
+
552
+
553
+ def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
554
+ """Save an image.
555
+
556
+ Args:
557
+ image (`PIL.Image`):
558
+ The image to be saved.
559
+ path (`str`):
560
+ The directory to save the image in. Note that the option `save_to_dirs` will cause the image to be saved into a subdirectory.
561
+ basename (`str`):
562
+ The base filename which will be applied to `filename pattern`.
563
+ seed, prompt, short_filename,
564
+ extension (`str`):
565
+ Image file extension, default is `png`.
566
+ pnginfo_section_name (`str`):
567
+ Specify the name of the section in which `info` will be saved.
568
+ info (`str` or `PngImagePlugin.iTXt`):
569
+ PNG info chunks.
570
+ existing_info (`dict`):
571
+ Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
572
+ no_prompt:
573
+ TODO I don't know its meaning.
574
+ p (`StableDiffusionProcessing`)
575
+ forced_filename (`str`):
576
+ If specified, `basename` and filename pattern will be ignored.
577
+ save_to_dirs (bool):
578
+ If true, the image will be saved into a subdirectory of `path`.
579
+
580
+ Returns: (fullfn, txt_fullfn)
581
+ fullfn (`str`):
582
+ The full path of the saved image.
583
+ txt_fullfn (`str` or None):
584
+ If a text file is saved for this image, this will be its full path. Otherwise None.
585
+ """
586
+ namegen = FilenameGenerator(p, seed, prompt, image)
587
+
588
+ if save_to_dirs is None:
589
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
590
+
591
+ if save_to_dirs:
592
+ dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
593
+ path = os.path.join(path, dirname)
594
+
595
+ os.makedirs(path, exist_ok=True)
596
+
597
+ if forced_filename is None:
598
+ if short_filename or seed is None:
599
+ file_decoration = ""
600
+ elif opts.save_to_dirs:
601
+ file_decoration = opts.samples_filename_pattern or "[seed]"
602
+ else:
603
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
604
+
605
+ file_decoration = namegen.apply(file_decoration) + suffix
606
+
607
+ add_number = opts.save_images_add_number or file_decoration == ''
608
+
609
+ if file_decoration != "" and add_number:
610
+ file_decoration = f"-{file_decoration}"
611
+
612
+ if add_number:
613
+ basecount = get_next_sequence_number(path, basename)
614
+ fullfn = None
615
+ for i in range(500):
616
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
617
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
618
+ if not os.path.exists(fullfn):
619
+ break
620
+ else:
621
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
622
+ else:
623
+ fullfn = os.path.join(path, f"{forced_filename}.{extension}")
624
+
625
+ pnginfo = existing_info or {}
626
+ if info is not None:
627
+ pnginfo[pnginfo_section_name] = info
628
+
629
+ params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
630
+ script_callbacks.before_image_saved_callback(params)
631
+
632
+ image = params.image
633
+ fullfn = params.filename
634
+ info = params.pnginfo.get(pnginfo_section_name, None)
635
+
636
+ def _atomically_save_image(image_to_save, filename_without_extension, extension):
637
+ """
638
+ save image with .tmp extension to avoid race condition when another process detects new image in the directory
639
+ """
640
+ temp_file_path = f"{filename_without_extension}.tmp"
641
+
642
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
643
+
644
+ os.replace(temp_file_path, filename_without_extension + extension)
645
+
646
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
647
+ if hasattr(os, 'statvfs'):
648
+ max_name_len = os.statvfs(path).f_namemax
649
+ fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
650
+ params.filename = fullfn_without_extension + extension
651
+ fullfn = params.filename
652
+ _atomically_save_image(image, fullfn_without_extension, extension)
653
+
654
+ image.already_saved_as = fullfn
655
+
656
+ oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
657
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
658
+ ratio = image.width / image.height
659
+ resize_to = None
660
+ if oversize and ratio > 1:
661
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
662
+ elif oversize:
663
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
664
+
665
+ if resize_to is not None:
666
+ try:
667
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
668
+ image = image.resize(resize_to, LANCZOS)
669
+ except Exception:
670
+ image = image.resize(resize_to)
671
+ try:
672
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
673
+ except Exception as e:
674
+ errors.display(e, "saving image as downscaled JPG")
675
+
676
+ if opts.save_txt and info is not None:
677
+ txt_fullfn = f"{fullfn_without_extension}.txt"
678
+ with open(txt_fullfn, "w", encoding="utf8") as file:
679
+ file.write(f"{info}\n")
680
+ else:
681
+ txt_fullfn = None
682
+
683
+ script_callbacks.image_saved_callback(params)
684
+
685
+ return fullfn, txt_fullfn
686
+
687
+
688
+ IGNORED_INFO_KEYS = {
689
+ 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
690
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
691
+ 'icc_profile', 'chromaticity', 'photoshop',
692
+ }
693
+
694
+
695
+ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
696
+ items = (image.info or {}).copy()
697
+
698
+ geninfo = items.pop('parameters', None)
699
+
700
+ if "exif" in items:
701
+ exif = piexif.load(items["exif"])
702
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
703
+ try:
704
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
705
+ except ValueError:
706
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
707
+
708
+ if exif_comment:
709
+ items['exif comment'] = exif_comment
710
+ geninfo = exif_comment
711
+
712
+ for field in IGNORED_INFO_KEYS:
713
+ items.pop(field, None)
714
+
715
+ if items.get("Software", None) == "NovelAI":
716
+ try:
717
+ json_info = json.loads(items["Comment"])
718
+ sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
719
+
720
+ geninfo = f"""{items["Description"]}
721
+ Negative prompt: {json_info["uc"]}
722
+ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
723
+ except Exception:
724
+ errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
725
+
726
+ return geninfo, items
727
+
728
+
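A round-trip sketch (illustrative only; the file name is hypothetical and opts.enable_pnginfo must be on): parameters written by save_image_with_geninfo into PNG text or EXIF can be recovered with read_info_from_image, which is what the PNG Info tab and the img2img batch PNG-info option rely on.

from PIL import Image

save_image_with_geninfo(Image.new("RGB", (64, 64)), "test prompt, Steps: 20", "demo.png")
geninfo, extra = read_info_from_image(Image.open("demo.png"))
print(geninfo)                                # -> "test prompt, Steps: 20"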
729
+ def image_data(data):
730
+ import gradio as gr
731
+
732
+ try:
733
+ image = Image.open(io.BytesIO(data))
734
+ textinfo, _ = read_info_from_image(image)
735
+ return textinfo, None
736
+ except Exception:
737
+ pass
738
+
739
+ try:
740
+ text = data.decode('utf8')
741
+ assert len(text) < 10000
742
+ return text, None
743
+
744
+ except Exception:
745
+ pass
746
+
747
+ return gr.update(), None
748
+
749
+
750
+ def flatten(img, bgcolor):
751
+ """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
752
+
753
+ if img.mode == "RGBA":
754
+ background = Image.new('RGBA', img.size, bgcolor)
755
+ background.paste(img, mask=img)
756
+ img = background
757
+
758
+ return img.convert('RGB')
modules/img2img.py ADDED
@@ -0,0 +1,245 @@
1
+ import os
2
+ from contextlib import closing
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
7
+ import gradio as gr
8
+
9
+ from modules import sd_samplers, images as imgutil
10
+ from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
11
+ from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
12
+ from modules.shared import opts, state
13
+ from modules.images import save_image
14
+ import modules.shared as shared
15
+ import modules.processing as processing
16
+ from modules.ui import plaintext_to_html
17
+ import modules.scripts
18
+
19
+
20
+ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
21
+ processing.fix_seed(p)
22
+
23
+ images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
24
+
25
+ is_inpaint_batch = False
26
+ if inpaint_mask_dir:
27
+ inpaint_masks = shared.listfiles(inpaint_mask_dir)
28
+ is_inpaint_batch = bool(inpaint_masks)
29
+
30
+ if is_inpaint_batch:
31
+ print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
32
+
33
+ print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
34
+
35
+ save_normally = output_dir == ''
36
+
37
+ p.do_not_save_grid = True
38
+ p.do_not_save_samples = not save_normally
39
+
40
+ state.job_count = len(images) * p.n_iter
41
+
42
+ # extract "default" params to use in case getting png info fails
43
+ prompt = p.prompt
44
+ negative_prompt = p.negative_prompt
45
+ seed = p.seed
46
+ cfg_scale = p.cfg_scale
47
+ sampler_name = p.sampler_name
48
+ steps = p.steps
49
+
50
+ for i, image in enumerate(images):
51
+ state.job = f"{i+1} out of {len(images)}"
52
+ if state.skipped:
53
+ state.skipped = False
54
+
55
+ if state.interrupted:
56
+ break
57
+
58
+ p.filename = os.path.basename(image)
59
+
60
+ try:
61
+ img = Image.open(image)
62
+ except UnidentifiedImageError as e:
63
+ print(e)
64
+ continue
65
+ # Use the EXIF orientation of photos taken by smartphones.
66
+ img = ImageOps.exif_transpose(img)
67
+
68
+ if to_scale:
69
+ p.width = int(img.width * scale_by)
70
+ p.height = int(img.height * scale_by)
71
+
72
+ p.init_images = [img] * p.batch_size
73
+
74
+ image_path = Path(image)
75
+ if is_inpaint_batch:
76
+ # try to find corresponding mask for an image using simple filename matching
77
+ if len(inpaint_masks) == 1:
78
+ mask_image_path = inpaint_masks[0]
79
+ else:
80
+ # try to find corresponding mask for an image using simple filename matching
81
+ mask_image_dir = Path(inpaint_mask_dir)
82
+ masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
83
+
84
+ if len(masks_found) == 0:
85
+ print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
86
+ continue
87
+
88
+ # it should contain only 1 matching mask
89
+ # otherwise user has many masks with the same name but different extensions
90
+ mask_image_path = masks_found[0]
91
+
92
+ mask_image = Image.open(mask_image_path)
93
+ p.image_mask = mask_image
94
+
95
+ if use_png_info:
96
+ try:
97
+ info_img = img
98
+ if png_info_dir:
99
+ info_img_path = os.path.join(png_info_dir, os.path.basename(image))
100
+ info_img = Image.open(info_img_path)
101
+ geninfo, _ = imgutil.read_info_from_image(info_img)
102
+ parsed_parameters = parse_generation_parameters(geninfo)
103
+ parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
104
+ except Exception:
105
+ parsed_parameters = {}
106
+
107
+ p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
108
+ p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
109
+ p.seed = int(parsed_parameters.get("Seed", seed))
110
+ p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
111
+ p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
112
+ p.steps = int(parsed_parameters.get("Steps", steps))
113
+
114
+ proc = modules.scripts.scripts_img2img.run(p, *args)
115
+ if proc is None:
116
+ proc = process_images(p)
117
+
118
+ for n, processed_image in enumerate(proc.images):
119
+ filename = image_path.stem
120
+ infotext = proc.infotext(p, n)
121
+ relpath = os.path.dirname(os.path.relpath(image, input_dir))
122
+
123
+ if n > 0:
124
+ filename += f"-{n}"
125
+
126
+ if not save_normally:
127
+ os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
128
+ if processed_image.mode == 'RGBA':
129
+ processed_image = processed_image.convert("RGB")
130
+ save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
131
+
132
+
133
+ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
134
+ override_settings = create_override_settings_dict(override_settings_texts)
135
+
136
+ is_batch = mode == 5
137
+
138
+ if mode == 0: # img2img
139
+ image = init_img.convert("RGB")
140
+ mask = None
141
+ elif mode == 1: # img2img sketch
142
+ image = sketch.convert("RGB")
143
+ mask = None
144
+ elif mode == 2: # inpaint
145
+ image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
146
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
147
+ mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
148
+ mask = ImageChops.lighter(alpha_mask, mask).convert('L')
149
+ image = image.convert("RGB")
150
+ elif mode == 3: # inpaint sketch
151
+ image = inpaint_color_sketch
152
+ orig = inpaint_color_sketch_orig or inpaint_color_sketch
153
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
154
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
155
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
156
+ blur = ImageFilter.GaussianBlur(mask_blur)
157
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
158
+ image = image.convert("RGB")
159
+ elif mode == 4: # inpaint upload mask
160
+ image = init_img_inpaint
161
+ mask = init_mask_inpaint
162
+ else:
163
+ image = None
164
+ mask = None
165
+
166
+ # Use the EXIF orientation of photos taken by smartphones.
167
+ if image is not None:
168
+ image = ImageOps.exif_transpose(image)
169
+
170
+ if selected_scale_tab == 1 and not is_batch:
171
+ assert image, "Can't scale by because no image is selected"
172
+
173
+ width = int(image.width * scale_by)
174
+ height = int(image.height * scale_by)
175
+
176
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
177
+
178
+ p = StableDiffusionProcessingImg2Img(
179
+ sd_model=shared.sd_model,
180
+ outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
181
+ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
182
+ prompt=prompt,
183
+ negative_prompt=negative_prompt,
184
+ styles=prompt_styles,
185
+ seed=seed,
186
+ subseed=subseed,
187
+ subseed_strength=subseed_strength,
188
+ seed_resize_from_h=seed_resize_from_h,
189
+ seed_resize_from_w=seed_resize_from_w,
190
+ seed_enable_extras=seed_enable_extras,
191
+ sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
192
+ batch_size=batch_size,
193
+ n_iter=n_iter,
194
+ steps=steps,
195
+ cfg_scale=cfg_scale,
196
+ width=width,
197
+ height=height,
198
+ restore_faces=restore_faces,
199
+ tiling=tiling,
200
+ init_images=[image],
201
+ mask=mask,
202
+ mask_blur=mask_blur,
203
+ inpainting_fill=inpainting_fill,
204
+ resize_mode=resize_mode,
205
+ denoising_strength=denoising_strength,
206
+ image_cfg_scale=image_cfg_scale,
207
+ inpaint_full_res=inpaint_full_res,
208
+ inpaint_full_res_padding=inpaint_full_res_padding,
209
+ inpainting_mask_invert=inpainting_mask_invert,
210
+ override_settings=override_settings,
211
+ )
212
+
213
+ p.scripts = modules.scripts.scripts_img2img
214
+ p.script_args = args
215
+
216
+ p.user = request.username
217
+
218
+ if shared.cmd_opts.enable_console_prompts:
219
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
220
+
221
+ if mask:
222
+ p.extra_generation_params["Mask blur"] = mask_blur
223
+
224
+ with closing(p):
225
+ if is_batch:
226
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
227
+
228
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
229
+
230
+ processed = Processed(p, [], p.seed, "")
231
+ else:
232
+ processed = modules.scripts.scripts_img2img.run(p, *args)
233
+ if processed is None:
234
+ processed = process_images(p)
235
+
236
+ shared.total_tqdm.clear()
237
+
238
+ generation_info_js = processed.js()
239
+ if opts.samples_log_stdout:
240
+ print(generation_info_js)
241
+
242
+ if opts.do_not_show_images:
243
+ processed.images = []
244
+
245
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
modules/import_hook.py ADDED
@@ -0,0 +1,5 @@
1
+ import sys
2
+
3
+ # this will break any attempt to import xformers, which prevents the stable diffusion repo from trying to use it
4
+ if "--xformers" not in "".join(sys.argv):
5
+ sys.modules["xformers"] = None
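The effect of assigning None into sys.modules can be seen with a small standalone snippet (illustrative only; the module name is made up): any later import of that name raises ImportError immediately, so downstream code falls back to its non-xformers path.

import sys

sys.modules["made_up_module"] = None          # hypothetical name standing in for "xformers"
try:
    import made_up_module
except ImportError:
    print("import blocked as intended")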
modules/interrogate.py ADDED
@@ -0,0 +1,223 @@
1
+ import os
2
+ import sys
3
+ from collections import namedtuple
4
+ from pathlib import Path
5
+ import re
6
+
7
+ import torch
8
+ import torch.hub
9
+
10
+ from torchvision import transforms
11
+ from torchvision.transforms.functional import InterpolationMode
12
+
13
+ from modules import devices, paths, shared, lowvram, modelloader, errors
14
+
15
+ blip_image_eval_size = 384
16
+ clip_model_name = 'ViT-L/14'
17
+
18
+ Category = namedtuple("Category", ["name", "topn", "items"])
19
+
20
+ re_topn = re.compile(r"\.top(\d+)\.")
21
+
22
+ def category_types():
23
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
24
+
25
+
26
+ def download_default_clip_interrogate_categories(content_dir):
27
+ print("Downloading CLIP categories...")
28
+
29
+ tmpdir = f"{content_dir}_tmp"
30
+ category_types = ["artists", "flavors", "mediums", "movements"]
31
+
32
+ try:
33
+ os.makedirs(tmpdir, exist_ok=True)
34
+ for category_type in category_types:
35
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
36
+ os.rename(tmpdir, content_dir)
37
+
38
+ except Exception as e:
39
+ errors.display(e, "downloading default CLIP interrogate categories")
40
+ finally:
41
+ if os.path.exists(tmpdir):
42
+ os.removedirs(tmpdir)
43
+
44
+
45
+ class InterrogateModels:
46
+ blip_model = None
47
+ clip_model = None
48
+ clip_preprocess = None
49
+ dtype = None
50
+ running_on_cpu = None
51
+
52
+ def __init__(self, content_dir):
53
+ self.loaded_categories = None
54
+ self.skip_categories = []
55
+ self.content_dir = content_dir
56
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
57
+
58
+ def categories(self):
59
+ if not os.path.exists(self.content_dir):
60
+ download_default_clip_interrogate_categories(self.content_dir)
61
+
62
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
63
+ return self.loaded_categories
64
+
65
+ self.loaded_categories = []
66
+
67
+ if os.path.exists(self.content_dir):
68
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
69
+ category_types = []
70
+ for filename in Path(self.content_dir).glob('*.txt'):
71
+ category_types.append(filename.stem)
72
+ if filename.stem in self.skip_categories:
73
+ continue
74
+ m = re_topn.search(filename.stem)
75
+ topn = 1 if m is None else int(m.group(1))
76
+ with open(filename, "r", encoding="utf8") as file:
77
+ lines = [x.strip() for x in file.readlines()]
78
+
79
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
80
+
81
+ return self.loaded_categories
82
+
83
+ def create_fake_fairscale(self):
84
+ class FakeFairscale:
85
+ def checkpoint_wrapper(self):
86
+ pass
87
+
88
+ sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
89
+
90
+ def load_blip_model(self):
91
+ self.create_fake_fairscale()
92
+ import models.blip
93
+
94
+ files = modelloader.load_models(
95
+ model_path=os.path.join(paths.models_path, "BLIP"),
96
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
97
+ ext_filter=[".pth"],
98
+ download_name='model_base_caption_capfilt_large.pth',
99
+ )
100
+
101
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
102
+ blip_model.eval()
103
+
104
+ return blip_model
105
+
106
+ def load_clip_model(self):
107
+ import clip
108
+
109
+ if self.running_on_cpu:
110
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
111
+ else:
112
+ model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
113
+
114
+ model.eval()
115
+ model = model.to(devices.device_interrogate)
116
+
117
+ return model, preprocess
118
+
119
+ def load(self):
120
+ if self.blip_model is None:
121
+ self.blip_model = self.load_blip_model()
122
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
123
+ self.blip_model = self.blip_model.half()
124
+
125
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
126
+
127
+ if self.clip_model is None:
128
+ self.clip_model, self.clip_preprocess = self.load_clip_model()
129
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
130
+ self.clip_model = self.clip_model.half()
131
+
132
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
133
+
134
+ self.dtype = next(self.clip_model.parameters()).dtype
135
+
136
+ def send_clip_to_ram(self):
137
+ if not shared.opts.interrogate_keep_models_in_memory:
138
+ if self.clip_model is not None:
139
+ self.clip_model = self.clip_model.to(devices.cpu)
140
+
141
+ def send_blip_to_ram(self):
142
+ if not shared.opts.interrogate_keep_models_in_memory:
143
+ if self.blip_model is not None:
144
+ self.blip_model = self.blip_model.to(devices.cpu)
145
+
146
+ def unload(self):
147
+ self.send_clip_to_ram()
148
+ self.send_blip_to_ram()
149
+
150
+ devices.torch_gc()
151
+
152
+ def rank(self, image_features, text_array, top_count=1):
153
+ import clip
154
+
155
+ devices.torch_gc()
156
+
157
+ if shared.opts.interrogate_clip_dict_limit != 0:
158
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
159
+
160
+ top_count = min(top_count, len(text_array))
161
+ text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
162
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
163
+ text_features /= text_features.norm(dim=-1, keepdim=True)
164
+
165
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
166
+ for i in range(image_features.shape[0]):
167
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
168
+ similarity /= image_features.shape[0]
169
+
170
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
171
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
172
+
173
+ def generate_caption(self, pil_image):
174
+ gpu_image = transforms.Compose([
175
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
176
+ transforms.ToTensor(),
177
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
178
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
179
+
180
+ with torch.no_grad():
181
+ caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
182
+
183
+ return caption[0]
184
+
185
+ def interrogate(self, pil_image):
186
+ res = ""
187
+ shared.state.begin(job="interrogate")
188
+ try:
189
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
190
+ lowvram.send_everything_to_cpu()
191
+ devices.torch_gc()
192
+
193
+ self.load()
194
+
195
+ caption = self.generate_caption(pil_image)
196
+ self.send_blip_to_ram()
197
+ devices.torch_gc()
198
+
199
+ res = caption
200
+
201
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
202
+
203
+ with torch.no_grad(), devices.autocast():
204
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
205
+
206
+ image_features /= image_features.norm(dim=-1, keepdim=True)
207
+
208
+ for cat in self.categories():
209
+ matches = self.rank(image_features, cat.items, top_count=cat.topn)
210
+ for match, score in matches:
211
+ if shared.opts.interrogate_return_ranks:
212
+ res += f", ({match}:{score/100:.3f})"
213
+ else:
214
+ res += f", {match}"
215
+
216
+ except Exception:
217
+ errors.report("Error interrogating", exc_info=True)
218
+ res += "<error>"
219
+
220
+ self.unload()
221
+ shared.state.end()
222
+
223
+ return res
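The core of InterrogateModels.rank can be sketched in isolation (illustrative only; random vectors stand in for real CLIP embeddings): with unit-normalised features, ranking reduces to a scaled dot product followed by a softmax and topk.

import torch

image_features = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
text_features = torch.nn.functional.normalize(torch.randn(5, 768), dim=-1)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = similarity.topk(3, dim=-1)    # probabilities and indices of the 3 best candidate phrases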
modules/launch_utils.py ADDED
@@ -0,0 +1,415 @@
1
+ # this script installs necessary requirements and launches the main program in webui.py
2
+ import re
3
+ import subprocess
4
+ import os
5
+ import sys
6
+ import importlib.util
7
+ import platform
8
+ import json
9
+ from functools import lru_cache
10
+
11
+ from modules import cmd_args, errors
12
+ from modules.paths_internal import script_path, extensions_dir
13
+ from modules.timer import startup_timer
14
+
15
+ args, _ = cmd_args.parser.parse_known_args()
16
+
17
+ python = sys.executable
18
+ git = os.environ.get('GIT', "git")
19
+ index_url = os.environ.get('INDEX_URL', "")
20
+ dir_repos = "repositories"
21
+
22
+ # Whether to default to printing command output
23
+ default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
24
+
25
+ if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
26
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
27
+
28
+
29
+ def check_python_version():
30
+ is_windows = platform.system() == "Windows"
31
+ major = sys.version_info.major
32
+ minor = sys.version_info.minor
33
+ micro = sys.version_info.micro
34
+
35
+ if is_windows:
36
+ supported_minors = [10]
37
+ else:
38
+ supported_minors = [7, 8, 9, 10, 11]
39
+
40
+ if not (major == 3 and minor in supported_minors):
41
+ import modules.errors
42
+
43
+ modules.errors.print_error_explanation(f"""
44
+ INCOMPATIBLE PYTHON VERSION
45
+
46
+ This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
47
+ If you encounter an error with "RuntimeError: Couldn't install torch." message,
48
+ or any other error regarding unsuccessful package (library) installation,
49
+ please downgrade (or upgrade) to the latest version of 3.10 Python
50
+ and delete current Python and "venv" folder in WebUI's directory.
51
+
52
+ You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
53
+
54
+ {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
55
+
56
+ Use --skip-python-version-check to suppress this warning.
57
+ """)
58
+
59
+
60
+ @lru_cache()
61
+ def commit_hash():
62
+ try:
63
+ return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
64
+ except Exception:
65
+ return "<none>"
66
+
67
+
68
+ @lru_cache()
69
+ def git_tag():
70
+ try:
71
+ return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
72
+ except Exception:
73
+ try:
74
+
75
+ changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")
76
+ with open(changelog_md, "r", encoding="utf-8") as file:
77
+ line = next((line.strip() for line in file if line.strip()), "<none>")
78
+ line = line.replace("## ", "")
79
+ return line
80
+ except Exception:
81
+ return "<none>"
82
+
83
+
84
+ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
85
+ if desc is not None:
86
+ print(desc)
87
+
88
+ run_kwargs = {
89
+ "args": command,
90
+ "shell": True,
91
+ "env": os.environ if custom_env is None else custom_env,
92
+ "encoding": 'utf8',
93
+ "errors": 'ignore',
94
+ }
95
+
96
+ if not live:
97
+ run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
98
+
99
+ result = subprocess.run(**run_kwargs)
100
+
101
+ if result.returncode != 0:
102
+ error_bits = [
103
+ f"{errdesc or 'Error running command'}.",
104
+ f"Command: {command}",
105
+ f"Error code: {result.returncode}",
106
+ ]
107
+ if result.stdout:
108
+ error_bits.append(f"stdout: {result.stdout}")
109
+ if result.stderr:
110
+ error_bits.append(f"stderr: {result.stderr}")
111
+ raise RuntimeError("\n".join(error_bits))
112
+
113
+ return (result.stdout or "")
114
+
115
+
116
+ def is_installed(package):
117
+ try:
118
+ spec = importlib.util.find_spec(package)
119
+ except ModuleNotFoundError:
120
+ return False
121
+
122
+ return spec is not None
123
+
124
+
125
+ def repo_dir(name):
126
+ return os.path.join(script_path, dir_repos, name)
127
+
128
+
129
+ def run_pip(command, desc=None, live=default_command_live):
130
+ if args.skip_install:
131
+ return
132
+
133
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
134
+ return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
135
+
136
+
137
+ def check_run_python(code: str) -> bool:
138
+ result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
139
+ return result.returncode == 0
140
+
141
+
142
+ def git_clone(url, dir, name, commithash=None):
143
+ # TODO clone into temporary dir and move if successful
144
+
145
+ if os.path.exists(dir):
146
+ if commithash is None:
147
+ return
148
+
149
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
150
+ if current_hash == commithash:
151
+ return
152
+
153
+ run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
154
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
155
+ return
156
+
157
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
158
+
159
+ if commithash is not None:
160
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
161
+
162
+
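A usage sketch of git_clone (illustrative only; the repository URL and commit hash are the same values used later in prepare_environment): when the target directory already exists, only a fetch/checkout of the pinned commit is performed, so repeated launches do not re-clone.

git_clone(
    "https://github.com/Stability-AI/stablediffusion.git",
    repo_dir("stable-diffusion-stability-ai"),
    "Stable Diffusion",
    commithash="cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf",
)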
163
+ def git_pull_recursive(dir):
164
+ for subdir, _, _ in os.walk(dir):
165
+ if os.path.exists(os.path.join(subdir, '.git')):
166
+ try:
167
+ output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
168
+ print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
169
+ except subprocess.CalledProcessError as e:
170
+ print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
171
+
172
+
173
+ def version_check(commit):
174
+ try:
175
+ import requests
176
+ commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
177
+ if commit != "<none>" and commits['commit']['sha'] != commit:
178
+ print("--------------------------------------------------------")
179
+ print("| You are not up to date with the most recent release. |")
180
+ print("| Consider running `git pull` to update. |")
181
+ print("--------------------------------------------------------")
182
+ elif commits['commit']['sha'] == commit:
183
+ print("You are up to date with the most recent release.")
184
+ else:
185
+ print("Not a git clone, can't perform version check.")
186
+ except Exception as e:
187
+ print("version check failed", e)
188
+
189
+
190
+ def run_extension_installer(extension_dir):
191
+ path_installer = os.path.join(extension_dir, "install.py")
192
+ if not os.path.isfile(path_installer):
193
+ return
194
+
195
+ try:
196
+ env = os.environ.copy()
197
+ env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
198
+
199
+ print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
200
+ except Exception as e:
201
+ errors.report(str(e))
202
+
203
+
204
+ def list_extensions(settings_file):
205
+ settings = {}
206
+
207
+ try:
208
+ if os.path.isfile(settings_file):
209
+ with open(settings_file, "r", encoding="utf8") as file:
210
+ settings = json.load(file)
211
+ except Exception:
212
+ errors.report("Could not load settings", exc_info=True)
213
+
214
+ disabled_extensions = set(settings.get('disabled_extensions', []))
215
+ disable_all_extensions = settings.get('disable_all_extensions', 'none')
216
+
217
+ if disable_all_extensions != 'none':
218
+ return []
219
+
220
+ return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
221
+
222
+
223
+ def run_extensions_installers(settings_file):
224
+ if not os.path.isdir(extensions_dir):
225
+ return
226
+
227
+ with startup_timer.subcategory("run extensions installers"):
228
+ for dirname_extension in list_extensions(settings_file):
229
+ path = os.path.join(extensions_dir, dirname_extension)
230
+
231
+ if os.path.isdir(path):
232
+ run_extension_installer(path)
233
+ startup_timer.record(dirname_extension)
234
+
235
+
236
+ re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
237
+
238
+
239
+ def requirements_met(requirements_file):
240
+ """
241
+ Does a simple parse of a requirements.txt file to determine if all requirements in it
242
+ are already installed. Returns True if so, False if not installed or parsing fails.
243
+ """
244
+
245
+ import importlib.metadata
246
+ import packaging.version
247
+
248
+ with open(requirements_file, "r", encoding="utf8") as file:
249
+ for line in file:
250
+ if line.strip() == "":
251
+ continue
252
+
253
+ m = re.match(re_requirement, line)
254
+ if m is None:
255
+ return False
256
+
257
+ package = m.group(1).strip()
258
+ version_required = (m.group(2) or "").strip()
259
+
260
+ if version_required == "":
261
+ continue
262
+
263
+ try:
264
+ version_installed = importlib.metadata.version(package)
265
+ except Exception:
266
+ return False
267
+
268
+ if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
269
+ return False
270
+
271
+ return True
272
+
273
+
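A sketch of what re_requirement accepts (illustrative only): only exact `package==version` pins are treated as constraints, unpinned lines are skipped, and any mismatched or unparseable line makes requirements_met return False so the launcher falls back to pip.

m = re.match(re_requirement, "torch==2.0.1")
print(m.group(1), m.group(2))                         # -> torch 2.0.1
print(re.match(re_requirement, "numpy").group(2))     # -> None, so the line is skipped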
274
+ def prepare_environment():
275
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
276
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
277
+ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
278
+
279
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
280
+ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
281
+ clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
282
+ openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
283
+
284
+ stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
285
+ stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
286
+ k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
287
+ codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
288
+ blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
289
+
290
+ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
291
+ stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
292
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
293
+ codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
294
+ blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
295
+
296
+ try:
297
+ # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
298
+ os.remove(os.path.join(script_path, "tmp", "restart"))
299
+ os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
300
+ except OSError:
301
+ pass
302
+
303
+ if not args.skip_python_version_check:
304
+ check_python_version()
305
+
306
+ startup_timer.record("checks")
307
+
308
+ commit = commit_hash()
309
+ tag = git_tag()
310
+ startup_timer.record("git version info")
311
+
312
+ print(f"Python {sys.version}")
313
+ print(f"Version: {tag}")
314
+ print(f"Commit hash: {commit}")
315
+
316
+ if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
317
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
318
+ startup_timer.record("install torch")
319
+
320
+ if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
321
+ raise RuntimeError(
322
+ 'Torch is not able to use GPU; '
323
+ 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
324
+ )
325
+ startup_timer.record("torch GPU test")
326
+
327
+
328
+ if not is_installed("gfpgan"):
329
+ run_pip(f"install {gfpgan_package}", "gfpgan")
330
+ startup_timer.record("install gfpgan")
331
+
332
+ if not is_installed("clip"):
333
+ run_pip(f"install {clip_package}", "clip")
334
+ startup_timer.record("install clip")
335
+
336
+ if not is_installed("open_clip"):
337
+ run_pip(f"install {openclip_package}", "open_clip")
338
+ startup_timer.record("install open_clip")
339
+
340
+ if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
341
+ if platform.system() == "Windows":
342
+ if platform.python_version().startswith("3.10"):
343
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
344
+ else:
345
+ print("Installation of xformers is not supported in this version of Python.")
346
+ print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
347
+ if not is_installed("xformers"):
348
+ exit(0)
349
+ elif platform.system() == "Linux":
350
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
351
+
352
+ startup_timer.record("install xformers")
353
+
354
+ if not is_installed("ngrok") and args.ngrok:
355
+ run_pip("install ngrok", "ngrok")
356
+ startup_timer.record("install ngrok")
357
+
358
+ os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
359
+
360
+ git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
361
+ git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
362
+ git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
363
+ git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
364
+ git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
365
+
366
+ startup_timer.record("clone repositores")
367
+
368
+ if not is_installed("lpips"):
369
+ run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
370
+ startup_timer.record("install CodeFormer requirements")
371
+
372
+ if not os.path.isfile(requirements_file):
373
+ requirements_file = os.path.join(script_path, requirements_file)
374
+
375
+ if not requirements_met(requirements_file):
376
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
377
+ startup_timer.record("install requirements")
378
+
379
+ run_extensions_installers(settings_file=args.ui_settings_file)
380
+
381
+ if args.update_check:
382
+ version_check(commit)
383
+ startup_timer.record("check version")
384
+
385
+ if args.update_all_extensions:
386
+ git_pull_recursive(extensions_dir)
387
+ startup_timer.record("update extensions")
388
+
389
+ if "--exit" in sys.argv:
390
+ print("Exiting because of --exit argument")
391
+ exit(0)
392
+
393
+
394
+
395
+ def configure_for_tests():
396
+ if "--api" not in sys.argv:
397
+ sys.argv.append("--api")
398
+ if "--ckpt" not in sys.argv:
399
+ sys.argv.append("--ckpt")
400
+ sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
401
+ if "--skip-torch-cuda-test" not in sys.argv:
402
+ sys.argv.append("--skip-torch-cuda-test")
403
+ if "--disable-nan-check" not in sys.argv:
404
+ sys.argv.append("--disable-nan-check")
405
+
406
+ os.environ['COMMANDLINE_ARGS'] = ""
407
+
408
+
409
+ def start():
410
+ print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
411
+ import webui
412
+ if '--nowebui' in sys.argv:
413
+ webui.api_only()
414
+ else:
415
+ webui.webui()
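The requirements_met() helper above boils down to matching each line against re_requirement and comparing the pinned version against the installed one. A minimal standalone sketch of that parse-and-compare step (the package names in the example calls are only illustrative):

    # Minimal sketch of the per-line check done by requirements_met() above.
    import re
    import importlib.metadata
    import packaging.version

    re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")

    def line_is_met(line: str) -> bool:
        m = re.match(re_requirement, line)
        if m is None:
            return False
        package = m.group(1).strip()
        version_required = (m.group(2) or "").strip()
        if version_required == "":
            return True  # no pin: nothing to compare, mirrors the `continue` above
        try:
            version_installed = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            return False
        return packaging.version.parse(version_required) == packaging.version.parse(version_installed)

    print(line_is_met("packaging"))         # True: unpinned lines are skipped
    print(line_is_met("packaging==0.0.1"))  # almost certainly False on any real install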
modules/localization.py ADDED
@@ -0,0 +1,35 @@
1
+ import json
2
+ import os
3
+
4
+ from modules import errors
5
+
6
+ localizations = {}
7
+
8
+
9
+ def list_localizations(dirname):
10
+ localizations.clear()
11
+
12
+ for file in os.listdir(dirname):
13
+ fn, ext = os.path.splitext(file)
14
+ if ext.lower() != ".json":
15
+ continue
16
+
17
+ localizations[fn] = os.path.join(dirname, file)
18
+
19
+ from modules import scripts
20
+ for file in scripts.list_scripts("localizations", ".json"):
21
+ fn, ext = os.path.splitext(file.filename)
22
+ localizations[fn] = file.path
23
+
24
+
25
+ def localization_js(current_localization_name: str) -> str:
26
+ fn = localizations.get(current_localization_name, None)
27
+ data = {}
28
+ if fn is not None:
29
+ try:
30
+ with open(fn, "r", encoding="utf8") as file:
31
+ data = json.load(file)
32
+ except Exception:
33
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
34
+
35
+ return f"window.localization = {json.dumps(data)}"
modules/lowvram.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ from modules import devices
3
+
4
+ module_in_gpu = None
5
+ cpu = torch.device("cpu")
6
+
7
+
8
+ def send_everything_to_cpu():
9
+ global module_in_gpu
10
+
11
+ if module_in_gpu is not None:
12
+ module_in_gpu.to(cpu)
13
+
14
+ module_in_gpu = None
15
+
16
+
17
+ def setup_for_low_vram(sd_model, use_medvram):
18
+ sd_model.lowvram = True
19
+
20
+ parents = {}
21
+
22
+ def send_me_to_gpu(module, _):
23
+ """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
24
+ we add this as forward_pre_hook to a lot of modules and this way all but one of them will
25
+ be in CPU
26
+ """
27
+ global module_in_gpu
28
+
29
+ module = parents.get(module, module)
30
+
31
+ if module_in_gpu == module:
32
+ return
33
+
34
+ if module_in_gpu is not None:
35
+ module_in_gpu.to(cpu)
36
+
37
+ module.to(devices.device)
38
+ module_in_gpu = module
39
+
40
+ # see below for register_forward_pre_hook;
41
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
42
+ # useless here, and we just replace those methods
43
+
44
+ first_stage_model = sd_model.first_stage_model
45
+ first_stage_model_encode = sd_model.first_stage_model.encode
46
+ first_stage_model_decode = sd_model.first_stage_model.decode
47
+
48
+ def first_stage_model_encode_wrap(x):
49
+ send_me_to_gpu(first_stage_model, None)
50
+ return first_stage_model_encode(x)
51
+
52
+ def first_stage_model_decode_wrap(z):
53
+ send_me_to_gpu(first_stage_model, None)
54
+ return first_stage_model_decode(z)
55
+
56
+ to_remain_in_cpu = [
57
+ (sd_model, 'first_stage_model'),
58
+ (sd_model, 'depth_model'),
59
+ (sd_model, 'embedder'),
60
+ (sd_model, 'model'),
61
+ (sd_model, 'embedder'),
62
+ ]
63
+
64
+ is_sdxl = hasattr(sd_model, 'conditioner')
65
+ is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
66
+
67
+ if is_sdxl:
68
+ to_remain_in_cpu.append((sd_model, 'conditioner'))
69
+ elif is_sd2:
70
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
71
+ else:
72
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
73
+
74
+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
75
+ stored = []
76
+ for obj, field in to_remain_in_cpu:
77
+ module = getattr(obj, field, None)
78
+ stored.append(module)
79
+ setattr(obj, field, None)
80
+
81
+ # send the model to GPU.
82
+ sd_model.to(devices.device)
83
+
84
+ # put modules back. the modules will be in CPU.
85
+ for (obj, field), module in zip(to_remain_in_cpu, stored):
86
+ setattr(obj, field, module)
87
+
88
+ # register hooks for the first three models
89
+ if is_sdxl:
90
+ sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
91
+ elif is_sd2:
92
+ sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
93
+ sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
94
+ parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
95
+ parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
96
+ else:
97
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
98
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
99
+
100
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
101
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
102
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
103
+ if sd_model.depth_model:
104
+ sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
105
+ if sd_model.embedder:
106
+ sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
107
+
108
+ if use_medvram:
109
+ sd_model.model.register_forward_pre_hook(send_me_to_gpu)
110
+ else:
111
+ diff_model = sd_model.model.diffusion_model
112
+
113
+ # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
114
+ # so that only one of them is in GPU at a time
115
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
116
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
117
+ sd_model.model.to(devices.device)
118
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
119
+
120
+ # install hooks for bits of third model
121
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
122
+ for block in diff_model.input_blocks:
123
+ block.register_forward_pre_hook(send_me_to_gpu)
124
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
125
+ for block in diff_model.output_blocks:
126
+ block.register_forward_pre_hook(send_me_to_gpu)
127
+
128
+
129
+ def is_enabled(sd_model):
130
+ return getattr(sd_model, 'lowvram', False)
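The whole scheme above hangs on register_forward_pre_hook: the module about to run is pulled onto the device and the previously resident one is pushed back out, so only one large module occupies VRAM at a time. A self-contained toy version of that pattern (both devices are 'cpu' here purely so the snippet runs anywhere; in practice the compute device would be CUDA):

    # Toy illustration of the "one module on the device at a time" hook pattern used above.
    import torch
    import torch.nn as nn

    storage_device = torch.device("cpu")
    compute_device = torch.device("cpu")   # stand-in; would be torch.device("cuda") in practice

    module_on_device = None

    def send_me_to_device(module, _inputs):
        global module_on_device
        if module_on_device is module:
            return
        if module_on_device is not None:
            module_on_device.to(storage_device)   # evict the previous tenant
        module.to(compute_device)                 # bring in the module that is about to run
        module_on_device = module

    layers = [nn.Linear(8, 8) for _ in range(3)]
    for layer in layers:
        layer.register_forward_pre_hook(send_me_to_device)

    x = torch.randn(1, 8)
    for layer in layers:
        x = layer(x)   # each forward() call triggers the hook first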
modules/mac_specific.py ADDED
@@ -0,0 +1,86 @@
1
+ import logging
2
+
3
+ import torch
4
+ import platform
5
+ from modules.sd_hijack_utils import CondFunc
6
+ from packaging import version
7
+
8
+ log = logging.getLogger(__name__)
9
+
10
+
11
+ # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
12
+ # so check with `getattr` and try a tensor op for compatibility.
13
+ # in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() were introduced to check MPS availability.
14
+ # since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
15
+ def check_for_mps() -> bool:
16
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
17
+ if not getattr(torch, 'has_mps', False):
18
+ return False
19
+ try:
20
+ torch.zeros(1).to(torch.device("mps"))
21
+ return True
22
+ except Exception:
23
+ return False
24
+ else:
25
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
26
+
27
+
28
+ has_mps = check_for_mps()
29
+
30
+
31
+ def torch_mps_gc() -> None:
32
+ try:
33
+ from modules.shared import state
34
+ if state.current_latent is not None:
35
+ log.debug("`current_latent` is set, skipping MPS garbage collection")
36
+ return
37
+ from torch.mps import empty_cache
38
+ empty_cache()
39
+ except Exception:
40
+ log.warning("MPS garbage collection failed", exc_info=True)
41
+
42
+
43
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
44
+ def cumsum_fix(input, cumsum_func, *args, **kwargs):
45
+ if input.device.type == 'mps':
46
+ output_dtype = kwargs.get('dtype', input.dtype)
47
+ if output_dtype == torch.int64:
48
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
49
+ elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
50
+ return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
51
+ return cumsum_func(input, *args, **kwargs)
52
+
53
+
54
+ if has_mps:
55
+ # MPS fix for randn in torchsde
56
+ CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
57
+
58
+ if platform.mac_ver()[0].startswith("13.2."):
59
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
60
+ CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
61
+
62
+ if version.parse(torch.__version__) < version.parse("1.13"):
63
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
64
+
65
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
66
+ CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
67
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
68
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
69
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
70
+ lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
71
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
72
+ CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
73
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
74
+ cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
75
+ cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
76
+ CondFunc('torch.cumsum', cumsum_fix_func, None)
77
+ CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
78
+ CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
79
+
80
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
81
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
82
+
83
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
84
+ if platform.processor() == 'i386':
85
+ for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
86
+ CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
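All of the patches above go through CondFunc from modules.sd_hijack_utils, which, judging from its use here, takes a dotted function path, a replacement that receives the original function as its first argument, and an optional predicate deciding when the replacement applies. A hypothetical standalone version of that idea (not the actual CondFunc implementation, and the float64 example is made up so the snippet runs on any device):

    # Hypothetical re-implementation of the conditional-patch pattern used above.
    import torch

    def patch_conditionally(obj, name, replacement, condition):
        orig = getattr(obj, name)
        def wrapper(*args, **kwargs):
            if condition is None or condition(*args, **kwargs):
                return replacement(orig, *args, **kwargs)
            return orig(*args, **kwargs)
        setattr(obj, name, wrapper)
        return orig

    # Illustrative dtype workaround: route float64 cumsum through float32.
    patch_conditionally(
        torch, "cumsum",
        lambda orig, input, *a, **kw: orig(input.to(torch.float32), *a, **kw),
        lambda input, *a, **kw: isinstance(input, torch.Tensor) and input.dtype == torch.float64,
    )
    print(torch.cumsum(torch.ones(3, dtype=torch.float64), dim=0).dtype)   # torch.float32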
modules/masking.py ADDED
@@ -0,0 +1,99 @@
1
+ from PIL import Image, ImageFilter, ImageOps
2
+
3
+
4
+ def get_crop_region(mask, pad=0):
5
+ """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6
+ For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
7
+
8
+ h, w = mask.shape
9
+
10
+ crop_left = 0
11
+ for i in range(w):
12
+ if not (mask[:, i] == 0).all():
13
+ break
14
+ crop_left += 1
15
+
16
+ crop_right = 0
17
+ for i in reversed(range(w)):
18
+ if not (mask[:, i] == 0).all():
19
+ break
20
+ crop_right += 1
21
+
22
+ crop_top = 0
23
+ for i in range(h):
24
+ if not (mask[i] == 0).all():
25
+ break
26
+ crop_top += 1
27
+
28
+ crop_bottom = 0
29
+ for i in reversed(range(h)):
30
+ if not (mask[i] == 0).all():
31
+ break
32
+ crop_bottom += 1
33
+
34
+ return (
35
+ int(max(crop_left-pad, 0)),
36
+ int(max(crop_top-pad, 0)),
37
+ int(min(w - crop_right + pad, w)),
38
+ int(min(h - crop_bottom + pad, h))
39
+ )
40
+
41
+
42
+ def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
43
+ """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
44
+ for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
45
+
46
+ x1, y1, x2, y2 = crop_region
47
+
48
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
49
+ ratio_processing = processing_width / processing_height
50
+
51
+ if ratio_crop_region > ratio_processing:
52
+ desired_height = (x2 - x1) / ratio_processing
53
+ desired_height_diff = int(desired_height - (y2-y1))
54
+ y1 -= desired_height_diff//2
55
+ y2 += desired_height_diff - desired_height_diff//2
56
+ if y2 >= image_height:
57
+ diff = y2 - image_height
58
+ y2 -= diff
59
+ y1 -= diff
60
+ if y1 < 0:
61
+ y2 -= y1
62
+ y1 -= y1
63
+ if y2 >= image_height:
64
+ y2 = image_height
65
+ else:
66
+ desired_width = (y2 - y1) * ratio_processing
67
+ desired_width_diff = int(desired_width - (x2-x1))
68
+ x1 -= desired_width_diff//2
69
+ x2 += desired_width_diff - desired_width_diff//2
70
+ if x2 >= image_width:
71
+ diff = x2 - image_width
72
+ x2 -= diff
73
+ x1 -= diff
74
+ if x1 < 0:
75
+ x2 -= x1
76
+ x1 -= x1
77
+ if x2 >= image_width:
78
+ x2 = image_width
79
+
80
+ return x1, y1, x2, y2
81
+
82
+
83
+ def fill(image, mask):
84
+ """fills masked regions with colors from image using blur. Not extremely effective."""
85
+
86
+ image_mod = Image.new('RGBA', (image.width, image.height))
87
+
88
+ image_masked = Image.new('RGBa', (image.width, image.height))
89
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
90
+
91
+ image_masked = image_masked.convert('RGBa')
92
+
93
+ for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
94
+ blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
95
+ for _ in range(repeats):
96
+ image_mod.alpha_composite(blurred)
97
+
98
+ return image_mod.convert("RGB")
99
+
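A small worked example for the two crop helpers above, using a NumPy array as the mask (the sizes are arbitrary, and the import assumes the webui package is available):

    # Worked example for get_crop_region / expand_crop_region; numbers are arbitrary.
    import numpy as np
    from modules.masking import get_crop_region, expand_crop_region

    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[0:128, 384:512] = 255                       # user painted the top-right corner

    region = get_crop_region(mask, pad=16)           # -> (368, 0, 512, 144)
    region = expand_crop_region(region, 512, 512, 512, 512)
    print(region)                                    # stays (368, 0, 512, 144): the 144x144 box already matches the 1:1 target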
modules/memmon.py ADDED
@@ -0,0 +1,92 @@
1
+ import threading
2
+ import time
3
+ from collections import defaultdict
4
+
5
+ import torch
6
+
7
+
8
+ class MemUsageMonitor(threading.Thread):
9
+ run_flag = None
10
+ device = None
11
+ disabled = False
12
+ opts = None
13
+ data = None
14
+
15
+ def __init__(self, name, device, opts):
16
+ threading.Thread.__init__(self)
17
+ self.name = name
18
+ self.device = device
19
+ self.opts = opts
20
+
21
+ self.daemon = True
22
+ self.run_flag = threading.Event()
23
+ self.data = defaultdict(int)
24
+
25
+ try:
26
+ self.cuda_mem_get_info()
27
+ torch.cuda.memory_stats(self.device)
28
+ except Exception as e: # AMD or whatever
29
+ print(f"Warning: caught exception '{e}', memory monitor disabled")
30
+ self.disabled = True
31
+
32
+ def cuda_mem_get_info(self):
33
+ index = self.device.index if self.device.index is not None else torch.cuda.current_device()
34
+ return torch.cuda.mem_get_info(index)
35
+
36
+ def run(self):
37
+ if self.disabled:
38
+ return
39
+
40
+ while True:
41
+ self.run_flag.wait()
42
+
43
+ torch.cuda.reset_peak_memory_stats()
44
+ self.data.clear()
45
+
46
+ if self.opts.memmon_poll_rate <= 0:
47
+ self.run_flag.clear()
48
+ continue
49
+
50
+ self.data["min_free"] = self.cuda_mem_get_info()[0]
51
+
52
+ while self.run_flag.is_set():
53
+ free, total = self.cuda_mem_get_info()
54
+ self.data["min_free"] = min(self.data["min_free"], free)
55
+
56
+ time.sleep(1 / self.opts.memmon_poll_rate)
57
+
58
+ def dump_debug(self):
59
+ print(self, 'recorded data:')
60
+ for k, v in self.read().items():
61
+ print(k, -(v // -(1024 ** 2)))
62
+
63
+ print(self, 'raw torch memory stats:')
64
+ tm = torch.cuda.memory_stats(self.device)
65
+ for k, v in tm.items():
66
+ if 'bytes' not in k:
67
+ continue
68
+ print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
69
+
70
+ print(torch.cuda.memory_summary())
71
+
72
+ def monitor(self):
73
+ self.run_flag.set()
74
+
75
+ def read(self):
76
+ if not self.disabled:
77
+ free, total = self.cuda_mem_get_info()
78
+ self.data["free"] = free
79
+ self.data["total"] = total
80
+
81
+ torch_stats = torch.cuda.memory_stats(self.device)
82
+ self.data["active"] = torch_stats["active.all.current"]
83
+ self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
84
+ self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
85
+ self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
86
+ self.data["system_peak"] = total - self.data["min_free"]
87
+
88
+ return self.data
89
+
90
+ def stop(self):
91
+ self.run_flag.clear()
92
+ return self.read()
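A rough sketch of how the monitor above is driven: start the thread once, arm it before a job, read the peaks afterwards. The opts object below is a hypothetical stand-in for the webui settings (in the application it comes from shared.opts), and the import assumes the webui modules are on the path:

    # Usage sketch for MemUsageMonitor; `opts` here is a hypothetical stand-in.
    import types
    import torch
    from modules.memmon import MemUsageMonitor

    opts = types.SimpleNamespace(memmon_poll_rate=8)    # samples per second
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    monitor = MemUsageMonitor("MemMon", device, opts)   # quietly disables itself without CUDA
    monitor.start()
    monitor.monitor()                                   # begin polling before the job
    # ... run the generation job ...
    peaks = monitor.stop()                              # returns the collected readings
    print(peaks.get("active_peak", 0) / (1024 ** 2), "MiB")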
modules/modelloader.py ADDED
@@ -0,0 +1,179 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ import importlib
6
+ from urllib.parse import urlparse
7
+
8
+ from modules import shared
9
+ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
10
+ from modules.paths import script_path, models_path
11
+
12
+
13
+ def load_file_from_url(
14
+ url: str,
15
+ *,
16
+ model_dir: str,
17
+ progress: bool = True,
18
+ file_name: str | None = None,
19
+ ) -> str:
20
+ """Download a file from `url` into `model_dir`, using the file present if possible.
21
+
22
+ Returns the path to the downloaded file.
23
+ """
24
+ os.makedirs(model_dir, exist_ok=True)
25
+ if not file_name:
26
+ parts = urlparse(url)
27
+ file_name = os.path.basename(parts.path)
28
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
29
+ if not os.path.exists(cached_file):
30
+ print(f'Downloading: "{url}" to {cached_file}\n')
31
+ from torch.hub import download_url_to_file
32
+ download_url_to_file(url, cached_file, progress=progress)
33
+ return cached_file
34
+
35
+
36
+ def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
37
+ """
38
+ A one-and-done loader that tries to find the desired models in the specified directories.
39
+
40
+ @param download_name: Specify to download from model_url immediately.
41
+ @param model_url: If no other models are found, this will be downloaded on upscale.
42
+ @param model_path: The location to store/find models in.
43
+ @param command_path: A command-line argument to search for models in first.
44
+ @param ext_filter: An optional list of filename extensions to filter by
45
+ @return: A list of paths containing the desired model(s)
46
+ """
47
+ output = []
48
+
49
+ try:
50
+ places = []
51
+
52
+ if command_path is not None and command_path != model_path:
53
+ pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
54
+ if os.path.exists(pretrained_path):
55
+ print(f"Appending path: {pretrained_path}")
56
+ places.append(pretrained_path)
57
+ elif os.path.exists(command_path):
58
+ places.append(command_path)
59
+
60
+ places.append(model_path)
61
+
62
+ for place in places:
63
+ for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
64
+ if os.path.islink(full_path) and not os.path.exists(full_path):
65
+ print(f"Skipping broken symlink: {full_path}")
66
+ continue
67
+ if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
68
+ continue
69
+ if full_path not in output:
70
+ output.append(full_path)
71
+
72
+ if model_url is not None and len(output) == 0:
73
+ if download_name is not None:
74
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
75
+ else:
76
+ output.append(model_url)
77
+
78
+ except Exception:
79
+ pass
80
+
81
+ return output
82
+
83
+
84
+ def friendly_name(file: str):
85
+ if file.startswith("http"):
86
+ file = urlparse(file).path
87
+
88
+ file = os.path.basename(file)
89
+ model_name, extension = os.path.splitext(file)
90
+ return model_name
91
+
92
+
93
+ def cleanup_models():
94
+ # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
95
+ # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
96
+ # somehow auto-register and just do these things...
97
+ root_path = script_path
98
+ src_path = models_path
99
+ dest_path = os.path.join(models_path, "Stable-diffusion")
100
+ move_files(src_path, dest_path, ".ckpt")
101
+ move_files(src_path, dest_path, ".safetensors")
102
+ src_path = os.path.join(root_path, "ESRGAN")
103
+ dest_path = os.path.join(models_path, "ESRGAN")
104
+ move_files(src_path, dest_path)
105
+ src_path = os.path.join(models_path, "BSRGAN")
106
+ dest_path = os.path.join(models_path, "ESRGAN")
107
+ move_files(src_path, dest_path, ".pth")
108
+ src_path = os.path.join(root_path, "gfpgan")
109
+ dest_path = os.path.join(models_path, "GFPGAN")
110
+ move_files(src_path, dest_path)
111
+ src_path = os.path.join(root_path, "SwinIR")
112
+ dest_path = os.path.join(models_path, "SwinIR")
113
+ move_files(src_path, dest_path)
114
+ src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
115
+ dest_path = os.path.join(models_path, "LDSR")
116
+ move_files(src_path, dest_path)
117
+
118
+
119
+ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
120
+ try:
121
+ os.makedirs(dest_path, exist_ok=True)
122
+ if os.path.exists(src_path):
123
+ for file in os.listdir(src_path):
124
+ fullpath = os.path.join(src_path, file)
125
+ if os.path.isfile(fullpath):
126
+ if ext_filter is not None:
127
+ if ext_filter not in file:
128
+ continue
129
+ print(f"Moving {file} from {src_path} to {dest_path}.")
130
+ try:
131
+ shutil.move(fullpath, dest_path)
132
+ except Exception:
133
+ pass
134
+ if len(os.listdir(src_path)) == 0:
135
+ print(f"Removing empty folder: {src_path}")
136
+ shutil.rmtree(src_path, True)
137
+ except Exception:
138
+ pass
139
+
140
+
141
+ def load_upscalers():
142
+ # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
143
+ # so we'll try to import any _model.py files before looking in __subclasses__
144
+ modules_dir = os.path.join(shared.script_path, "modules")
145
+ for file in os.listdir(modules_dir):
146
+ if "_model.py" in file:
147
+ model_name = file.replace("_model.py", "")
148
+ full_model = f"modules.{model_name}_model"
149
+ try:
150
+ importlib.import_module(full_model)
151
+ except Exception:
152
+ pass
153
+
154
+ datas = []
155
+ commandline_options = vars(shared.cmd_opts)
156
+
157
+ # some of the upscaler classes will not go away after reloading their modules, and we'll end
158
+ # up with two copies of those classes. The newest copy will always be the last in the list,
159
+ # so we go from end to beginning and ignore duplicates
160
+ used_classes = {}
161
+ for cls in reversed(Upscaler.__subclasses__()):
162
+ classname = str(cls)
163
+ if classname not in used_classes:
164
+ used_classes[classname] = cls
165
+
166
+ for cls in reversed(used_classes.values()):
167
+ name = cls.__name__
168
+ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
169
+ commandline_model_path = commandline_options.get(cmd_name, None)
170
+ scaler = cls(commandline_model_path)
171
+ scaler.user_path = commandline_model_path
172
+ scaler.model_download_path = commandline_model_path or scaler.model_path
173
+ datas += scaler.scalers
174
+
175
+ shared.sd_upscalers = sorted(
176
+ datas,
177
+ # Special case for UpscalerNone keeps it at the beginning of the list.
178
+ key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
179
+ )
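For reference, the two small helpers at the top of modelloader.py compose like this; the URL and directory below are placeholders, not real resources:

    # Illustrative use of load_file_from_url / friendly_name; URL and paths are placeholders.
    from modules.modelloader import load_file_from_url, friendly_name

    url = "https://example.com/upscalers/4x_example.pth"
    path = load_file_from_url(url, model_dir="models/ESRGAN")   # downloads once, reuses the cached file afterwards
    print(path)                                                  # .../models/ESRGAN/4x_example.pth
    print(friendly_name(url))                                    # "4x_example"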
modules/models/diffusion/ddpm_edit.py ADDED
@@ -0,0 +1,1455 @@
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
10
+ # See more details in LICENSE.
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ import pytorch_lightning as pl
16
+ from torch.optim.lr_scheduler import LambdaLR
17
+ from einops import rearrange, repeat
18
+ from contextlib import contextmanager
19
+ from functools import partial
20
+ from tqdm import tqdm
21
+ from torchvision.utils import make_grid
22
+ from pytorch_lightning.utilities.distributed import rank_zero_only
23
+
24
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
25
+ from ldm.modules.ema import LitEma
26
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
27
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
28
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
29
+ from ldm.models.diffusion.ddim import DDIMSampler
30
+
31
+
32
+ __conditioning_keys__ = {'concat': 'c_concat',
33
+ 'crossattn': 'c_crossattn',
34
+ 'adm': 'y'}
35
+
36
+
37
+ def disabled_train(self, mode=True):
38
+ """Overwrite model.train with this function to make sure train/eval mode
39
+ does not change anymore."""
40
+ return self
41
+
42
+
43
+ def uniform_on_device(r1, r2, shape, device):
44
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
45
+
46
+
47
+ class DDPM(pl.LightningModule):
48
+ # classic DDPM with Gaussian diffusion, in image space
49
+ def __init__(self,
50
+ unet_config,
51
+ timesteps=1000,
52
+ beta_schedule="linear",
53
+ loss_type="l2",
54
+ ckpt_path=None,
55
+ ignore_keys=None,
56
+ load_only_unet=False,
57
+ monitor="val/loss",
58
+ use_ema=True,
59
+ first_stage_key="image",
60
+ image_size=256,
61
+ channels=3,
62
+ log_every_t=100,
63
+ clip_denoised=True,
64
+ linear_start=1e-4,
65
+ linear_end=2e-2,
66
+ cosine_s=8e-3,
67
+ given_betas=None,
68
+ original_elbo_weight=0.,
69
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
70
+ l_simple_weight=1.,
71
+ conditioning_key=None,
72
+ parameterization="eps", # all assuming fixed variance schedules
73
+ scheduler_config=None,
74
+ use_positional_encodings=False,
75
+ learn_logvar=False,
76
+ logvar_init=0.,
77
+ load_ema=True,
78
+ ):
79
+ super().__init__()
80
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
81
+ self.parameterization = parameterization
82
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
83
+ self.cond_stage_model = None
84
+ self.clip_denoised = clip_denoised
85
+ self.log_every_t = log_every_t
86
+ self.first_stage_key = first_stage_key
87
+ self.image_size = image_size # try conv?
88
+ self.channels = channels
89
+ self.use_positional_encodings = use_positional_encodings
90
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
91
+ count_params(self.model, verbose=True)
92
+ self.use_ema = use_ema
93
+
94
+ self.use_scheduler = scheduler_config is not None
95
+ if self.use_scheduler:
96
+ self.scheduler_config = scheduler_config
97
+
98
+ self.v_posterior = v_posterior
99
+ self.original_elbo_weight = original_elbo_weight
100
+ self.l_simple_weight = l_simple_weight
101
+
102
+ if monitor is not None:
103
+ self.monitor = monitor
104
+
105
+ if self.use_ema and load_ema:
106
+ self.model_ema = LitEma(self.model)
107
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
108
+
109
+ if ckpt_path is not None:
110
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
111
+
112
+ # If initializing from an EMA-only checkpoint, create the EMA model after loading.
113
+ if self.use_ema and not load_ema:
114
+ self.model_ema = LitEma(self.model)
115
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
116
+
117
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
118
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
119
+
120
+ self.loss_type = loss_type
121
+
122
+ self.learn_logvar = learn_logvar
123
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
124
+ if self.learn_logvar:
125
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
126
+
127
+
128
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
129
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
130
+ if exists(given_betas):
131
+ betas = given_betas
132
+ else:
133
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
134
+ cosine_s=cosine_s)
135
+ alphas = 1. - betas
136
+ alphas_cumprod = np.cumprod(alphas, axis=0)
137
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
138
+
139
+ timesteps, = betas.shape
140
+ self.num_timesteps = int(timesteps)
141
+ self.linear_start = linear_start
142
+ self.linear_end = linear_end
143
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
144
+
145
+ to_torch = partial(torch.tensor, dtype=torch.float32)
146
+
147
+ self.register_buffer('betas', to_torch(betas))
148
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
149
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
150
+
151
+ # calculations for diffusion q(x_t | x_{t-1}) and others
152
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
153
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
154
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
155
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
156
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
157
+
158
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
159
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
160
+ 1. - alphas_cumprod) + self.v_posterior * betas
161
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
162
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
163
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
164
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
165
+ self.register_buffer('posterior_mean_coef1', to_torch(
166
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
167
+ self.register_buffer('posterior_mean_coef2', to_torch(
168
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
169
+
170
+ if self.parameterization == "eps":
171
+ lvlb_weights = self.betas ** 2 / (
172
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
173
+ elif self.parameterization == "x0":
174
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
175
+ else:
176
+ raise NotImplementedError("mu not supported")
177
+ # TODO how to choose this term
178
+ lvlb_weights[0] = lvlb_weights[1]
179
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
180
+ assert not torch.isnan(self.lvlb_weights).all()
181
+
182
+ @contextmanager
183
+ def ema_scope(self, context=None):
184
+ if self.use_ema:
185
+ self.model_ema.store(self.model.parameters())
186
+ self.model_ema.copy_to(self.model)
187
+ if context is not None:
188
+ print(f"{context}: Switched to EMA weights")
189
+ try:
190
+ yield None
191
+ finally:
192
+ if self.use_ema:
193
+ self.model_ema.restore(self.model.parameters())
194
+ if context is not None:
195
+ print(f"{context}: Restored training weights")
196
+
197
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
198
+ ignore_keys = ignore_keys or []
199
+
200
+ sd = torch.load(path, map_location="cpu")
201
+ if "state_dict" in list(sd.keys()):
202
+ sd = sd["state_dict"]
203
+ keys = list(sd.keys())
204
+
205
+ # Our model adds additional channels to the first layer to condition on an input image.
206
+ # For the first layer, copy existing channel weights and initialize new channel weights to zero.
207
+ input_keys = [
208
+ "model.diffusion_model.input_blocks.0.0.weight",
209
+ "model_ema.diffusion_modelinput_blocks00weight",
210
+ ]
211
+
212
+ self_sd = self.state_dict()
213
+ for input_key in input_keys:
214
+ if input_key not in sd or input_key not in self_sd:
215
+ continue
216
+
217
+ input_weight = self_sd[input_key]
218
+
219
+ if input_weight.size() != sd[input_key].size():
220
+ print(f"Manual init: {input_key}")
221
+ input_weight.zero_()
222
+ input_weight[:, :4, :, :].copy_(sd[input_key])
223
+ ignore_keys.append(input_key)
224
+
225
+ for k in keys:
226
+ for ik in ignore_keys:
227
+ if k.startswith(ik):
228
+ print(f"Deleting key {k} from state_dict.")
229
+ del sd[k]
230
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
231
+ sd, strict=False)
232
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
233
+ if missing:
234
+ print(f"Missing Keys: {missing}")
235
+ if unexpected:
236
+ print(f"Unexpected Keys: {unexpected}")
237
+
238
+ def q_mean_variance(self, x_start, t):
239
+ """
240
+ Get the distribution q(x_t | x_0).
241
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
242
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
243
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
244
+ """
245
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
246
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
247
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
248
+ return mean, variance, log_variance
249
+
250
+ def predict_start_from_noise(self, x_t, t, noise):
251
+ return (
252
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
253
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
254
+ )
255
+
256
+ def q_posterior(self, x_start, x_t, t):
257
+ posterior_mean = (
258
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
259
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
260
+ )
261
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
262
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
263
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
264
+
265
+ def p_mean_variance(self, x, t, clip_denoised: bool):
266
+ model_out = self.model(x, t)
267
+ if self.parameterization == "eps":
268
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
269
+ elif self.parameterization == "x0":
270
+ x_recon = model_out
271
+ if clip_denoised:
272
+ x_recon.clamp_(-1., 1.)
273
+
274
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
275
+ return model_mean, posterior_variance, posterior_log_variance
276
+
277
+ @torch.no_grad()
278
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
279
+ b, *_, device = *x.shape, x.device
280
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
281
+ noise = noise_like(x.shape, device, repeat_noise)
282
+ # no noise when t == 0
283
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
284
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
285
+
286
+ @torch.no_grad()
287
+ def p_sample_loop(self, shape, return_intermediates=False):
288
+ device = self.betas.device
289
+ b = shape[0]
290
+ img = torch.randn(shape, device=device)
291
+ intermediates = [img]
292
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
293
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
294
+ clip_denoised=self.clip_denoised)
295
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
296
+ intermediates.append(img)
297
+ if return_intermediates:
298
+ return img, intermediates
299
+ return img
300
+
301
+ @torch.no_grad()
302
+ def sample(self, batch_size=16, return_intermediates=False):
303
+ image_size = self.image_size
304
+ channels = self.channels
305
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
306
+ return_intermediates=return_intermediates)
307
+
308
+ def q_sample(self, x_start, t, noise=None):
309
+ noise = default(noise, lambda: torch.randn_like(x_start))
310
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
311
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
312
+
313
+ def get_loss(self, pred, target, mean=True):
314
+ if self.loss_type == 'l1':
315
+ loss = (target - pred).abs()
316
+ if mean:
317
+ loss = loss.mean()
318
+ elif self.loss_type == 'l2':
319
+ if mean:
320
+ loss = torch.nn.functional.mse_loss(target, pred)
321
+ else:
322
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
323
+ else:
324
+ raise NotImplementedError("unknown loss type '{loss_type}'")
325
+
326
+ return loss
327
+
328
+ def p_losses(self, x_start, t, noise=None):
329
+ noise = default(noise, lambda: torch.randn_like(x_start))
330
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
331
+ model_out = self.model(x_noisy, t)
332
+
333
+ loss_dict = {}
334
+ if self.parameterization == "eps":
335
+ target = noise
336
+ elif self.parameterization == "x0":
337
+ target = x_start
338
+ else:
339
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
340
+
341
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
342
+
343
+ log_prefix = 'train' if self.training else 'val'
344
+
345
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
346
+ loss_simple = loss.mean() * self.l_simple_weight
347
+
348
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
349
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
350
+
351
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
352
+
353
+ loss_dict.update({f'{log_prefix}/loss': loss})
354
+
355
+ return loss, loss_dict
356
+
357
+ def forward(self, x, *args, **kwargs):
358
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
359
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
360
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
361
+ return self.p_losses(x, t, *args, **kwargs)
362
+
363
+ def get_input(self, batch, k):
364
+ return batch[k]
365
+
366
+ def shared_step(self, batch):
367
+ x = self.get_input(batch, self.first_stage_key)
368
+ loss, loss_dict = self(x)
369
+ return loss, loss_dict
370
+
371
+ def training_step(self, batch, batch_idx):
372
+ loss, loss_dict = self.shared_step(batch)
373
+
374
+ self.log_dict(loss_dict, prog_bar=True,
375
+ logger=True, on_step=True, on_epoch=True)
376
+
377
+ self.log("global_step", self.global_step,
378
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
379
+
380
+ if self.use_scheduler:
381
+ lr = self.optimizers().param_groups[0]['lr']
382
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
383
+
384
+ return loss
385
+
386
+ @torch.no_grad()
387
+ def validation_step(self, batch, batch_idx):
388
+ _, loss_dict_no_ema = self.shared_step(batch)
389
+ with self.ema_scope():
390
+ _, loss_dict_ema = self.shared_step(batch)
391
+ loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
392
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
393
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
394
+
395
+ def on_train_batch_end(self, *args, **kwargs):
396
+ if self.use_ema:
397
+ self.model_ema(self.model)
398
+
399
+ def _get_rows_from_list(self, samples):
400
+ n_imgs_per_row = len(samples)
401
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
402
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
403
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
404
+ return denoise_grid
405
+
406
+ @torch.no_grad()
407
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
408
+ log = {}
409
+ x = self.get_input(batch, self.first_stage_key)
410
+ N = min(x.shape[0], N)
411
+ n_row = min(x.shape[0], n_row)
412
+ x = x.to(self.device)[:N]
413
+ log["inputs"] = x
414
+
415
+ # get diffusion row
416
+ diffusion_row = []
417
+ x_start = x[:n_row]
418
+
419
+ for t in range(self.num_timesteps):
420
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
421
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
422
+ t = t.to(self.device).long()
423
+ noise = torch.randn_like(x_start)
424
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
425
+ diffusion_row.append(x_noisy)
426
+
427
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
428
+
429
+ if sample:
430
+ # get denoise row
431
+ with self.ema_scope("Plotting"):
432
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
433
+
434
+ log["samples"] = samples
435
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
436
+
437
+ if return_keys:
438
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
439
+ return log
440
+ else:
441
+ return {key: log[key] for key in return_keys}
442
+ return log
443
+
444
+ def configure_optimizers(self):
445
+ lr = self.learning_rate
446
+ params = list(self.model.parameters())
447
+ if self.learn_logvar:
448
+ params = params + [self.logvar]
449
+ opt = torch.optim.AdamW(params, lr=lr)
450
+ return opt
451
+
452
+
453
+ class LatentDiffusion(DDPM):
454
+ """main class"""
455
+ def __init__(self,
456
+ first_stage_config,
457
+ cond_stage_config,
458
+ num_timesteps_cond=None,
459
+ cond_stage_key="image",
460
+ cond_stage_trainable=False,
461
+ concat_mode=True,
462
+ cond_stage_forward=None,
463
+ conditioning_key=None,
464
+ scale_factor=1.0,
465
+ scale_by_std=False,
466
+ load_ema=True,
467
+ *args, **kwargs):
468
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
469
+ self.scale_by_std = scale_by_std
470
+ assert self.num_timesteps_cond <= kwargs['timesteps']
471
+ # for backwards compatibility after implementation of DiffusionWrapper
472
+ if conditioning_key is None:
473
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
474
+ if cond_stage_config == '__is_unconditional__':
475
+ conditioning_key = None
476
+ ckpt_path = kwargs.pop("ckpt_path", None)
477
+ ignore_keys = kwargs.pop("ignore_keys", [])
478
+ super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
479
+ self.concat_mode = concat_mode
480
+ self.cond_stage_trainable = cond_stage_trainable
481
+ self.cond_stage_key = cond_stage_key
482
+ try:
483
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
484
+ except Exception:
485
+ self.num_downs = 0
486
+ if not scale_by_std:
487
+ self.scale_factor = scale_factor
488
+ else:
489
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
490
+ self.instantiate_first_stage(first_stage_config)
491
+ self.instantiate_cond_stage(cond_stage_config)
492
+ self.cond_stage_forward = cond_stage_forward
493
+ self.clip_denoised = False
494
+ self.bbox_tokenizer = None
495
+
496
+ self.restarted_from_ckpt = False
497
+ if ckpt_path is not None:
498
+ self.init_from_ckpt(ckpt_path, ignore_keys)
499
+ self.restarted_from_ckpt = True
500
+
501
+ if self.use_ema and not load_ema:
502
+ self.model_ema = LitEma(self.model)
503
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
504
+
505
+ def make_cond_schedule(self, ):
506
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
507
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
508
+ self.cond_ids[:self.num_timesteps_cond] = ids
509
+
510
+ @rank_zero_only
511
+ @torch.no_grad()
512
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
513
+ # only for very first batch
514
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
515
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
516
+ # set rescale weight to 1./std of encodings
517
+ print("### USING STD-RESCALING ###")
518
+ x = super().get_input(batch, self.first_stage_key)
519
+ x = x.to(self.device)
520
+ encoder_posterior = self.encode_first_stage(x)
521
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
522
+ del self.scale_factor
523
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
524
+ print(f"setting self.scale_factor to {self.scale_factor}")
525
+ print("### USING STD-RESCALING ###")
526
+
527
+ def register_schedule(self,
528
+ given_betas=None, beta_schedule="linear", timesteps=1000,
529
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
530
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
531
+
532
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
533
+ if self.shorten_cond_schedule:
534
+ self.make_cond_schedule()
535
+
536
+ def instantiate_first_stage(self, config):
537
+ model = instantiate_from_config(config)
538
+ self.first_stage_model = model.eval()
539
+ self.first_stage_model.train = disabled_train
540
+ for param in self.first_stage_model.parameters():
541
+ param.requires_grad = False
542
+
543
+ def instantiate_cond_stage(self, config):
544
+ if not self.cond_stage_trainable:
545
+ if config == "__is_first_stage__":
546
+ print("Using first stage also as cond stage.")
547
+ self.cond_stage_model = self.first_stage_model
548
+ elif config == "__is_unconditional__":
549
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
550
+ self.cond_stage_model = None
551
+ # self.be_unconditional = True
552
+ else:
553
+ model = instantiate_from_config(config)
554
+ self.cond_stage_model = model.eval()
555
+ self.cond_stage_model.train = disabled_train
556
+ for param in self.cond_stage_model.parameters():
557
+ param.requires_grad = False
558
+ else:
559
+ assert config != '__is_first_stage__'
560
+ assert config != '__is_unconditional__'
561
+ model = instantiate_from_config(config)
562
+ self.cond_stage_model = model
563
+
564
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
565
+ denoise_row = []
566
+ for zd in tqdm(samples, desc=desc):
567
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
568
+ force_not_quantize=force_no_decoder_quantization))
569
+ n_imgs_per_row = len(denoise_row)
570
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
571
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
572
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
573
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
574
+ return denoise_grid
575
+
576
+ def get_first_stage_encoding(self, encoder_posterior):
577
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
578
+ z = encoder_posterior.sample()
579
+ elif isinstance(encoder_posterior, torch.Tensor):
580
+ z = encoder_posterior
581
+ else:
582
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
583
+ return self.scale_factor * z
584
+
585
+ def get_learned_conditioning(self, c):
586
+ if self.cond_stage_forward is None:
587
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
588
+ c = self.cond_stage_model.encode(c)
589
+ if isinstance(c, DiagonalGaussianDistribution):
590
+ c = c.mode()
591
+ else:
592
+ c = self.cond_stage_model(c)
593
+ else:
594
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
595
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
596
+ return c
597
+
598
+ def meshgrid(self, h, w):
599
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
600
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
601
+
602
+ arr = torch.cat([y, x], dim=-1)
603
+ return arr
604
+
605
+ def delta_border(self, h, w):
606
+ """
607
+ :param h: height
608
+ :param w: width
609
+ :return: normalized distance to image border,
610
+         with min distance = 0 at border and max dist = 0.5 at image center
611
+ """
612
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
613
+ arr = self.meshgrid(h, w) / lower_right_corner
614
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
615
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
616
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
617
+ return edge_dist
618
+
619
+ def get_weighting(self, h, w, Ly, Lx, device):
620
+ weighting = self.delta_border(h, w)
621
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
622
+ self.split_input_params["clip_max_weight"], )
623
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
624
+
625
+ if self.split_input_params["tie_braker"]:
626
+ L_weighting = self.delta_border(Ly, Lx)
627
+ L_weighting = torch.clip(L_weighting,
628
+ self.split_input_params["clip_min_tie_weight"],
629
+ self.split_input_params["clip_max_tie_weight"])
630
+
631
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
632
+ weighting = weighting * L_weighting
633
+ return weighting
634
+
635
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
636
+ """
637
+ :param x: img of size (bs, c, h, w)
638
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
639
+ """
640
+ bs, nc, h, w = x.shape
641
+
642
+ # number of crops in image
643
+ Ly = (h - kernel_size[0]) // stride[0] + 1
644
+ Lx = (w - kernel_size[1]) // stride[1] + 1
645
+
646
+ if uf == 1 and df == 1:
647
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
648
+ unfold = torch.nn.Unfold(**fold_params)
649
+
650
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
651
+
652
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
653
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
654
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
655
+
656
+ elif uf > 1 and df == 1:
657
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
658
+ unfold = torch.nn.Unfold(**fold_params)
659
+
660
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
661
+ dilation=1, padding=0,
662
+ stride=(stride[0] * uf, stride[1] * uf))
663
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
664
+
665
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
666
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
667
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
668
+
669
+ elif df > 1 and uf == 1:
670
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
671
+ unfold = torch.nn.Unfold(**fold_params)
672
+
673
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
674
+ dilation=1, padding=0,
675
+ stride=(stride[0] // df, stride[1] // df))
676
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
677
+
678
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
679
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
680
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
681
+
682
+ else:
683
+ raise NotImplementedError
684
+
685
+ return fold, unfold, normalization, weighting
686
+
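A minimal standalone sketch (editor illustration, not part of this file) of why the `fold(...) / normalization` pattern used by the callers of get_fold_unfold recovers the original tensor: folding overlapping patches adds each pixel once per covering patch, so dividing by the folded weighting cancels the overlap. The kernel and stride values are illustrative.

import torch

x = torch.randn(1, 3, 16, 16)
ks, stride = (8, 8), (4, 4)                        # illustrative patch size / stride
unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=x.shape[2:], kernel_size=ks, stride=stride)

patches = unfold(x)                                # (1, 3 * 8 * 8, L)
weighting = torch.ones_like(patches)               # uniform weights = plain overlap count
normalization = fold(weighting)                    # how many patches cover each pixel
x_rec = fold(patches) / normalization              # stitch crops together, undo the overlap
print(torch.allclose(x_rec, x, atol=1e-6))         # True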
687
+ @torch.no_grad()
688
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
689
+ cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
690
+ x = super().get_input(batch, k)
691
+ if bs is not None:
692
+ x = x[:bs]
693
+ x = x.to(self.device)
694
+ encoder_posterior = self.encode_first_stage(x)
695
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
696
+ cond_key = cond_key or self.cond_stage_key
697
+ xc = super().get_input(batch, cond_key)
698
+ if bs is not None:
699
+ xc["c_crossattn"] = xc["c_crossattn"][:bs]
700
+ xc["c_concat"] = xc["c_concat"][:bs]
701
+ cond = {}
702
+
703
+         # To support classifier-free guidance, randomly drop out only the text conditioning for 5% of samples, only the image conditioning for 5%, and both for 5%.
704
+ random = torch.rand(x.size(0), device=x.device)
705
+ prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
706
+ input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
707
+
708
+ null_prompt = self.get_learned_conditioning([""])
709
+ cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
710
+ cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
711
+
712
+ out = [z, cond]
713
+ if return_first_stage_outputs:
714
+ xrec = self.decode_first_stage(z)
715
+ out.extend([x, xrec])
716
+ if return_original_cond:
717
+ out.append(xc)
718
+ return out
719
+
720
+ @torch.no_grad()
721
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
722
+ if predict_cids:
723
+ if z.dim() == 4:
724
+ z = torch.argmax(z.exp(), dim=1).long()
725
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
726
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
727
+
728
+ z = 1. / self.scale_factor * z
729
+
730
+ if hasattr(self, "split_input_params"):
731
+ if self.split_input_params["patch_distributed_vq"]:
732
+ ks = self.split_input_params["ks"] # eg. (128, 128)
733
+ stride = self.split_input_params["stride"] # eg. (64, 64)
734
+ uf = self.split_input_params["vqf"]
735
+ bs, nc, h, w = z.shape
736
+ if ks[0] > h or ks[1] > w:
737
+ ks = (min(ks[0], h), min(ks[1], w))
738
+ print("reducing Kernel")
739
+
740
+ if stride[0] > h or stride[1] > w:
741
+ stride = (min(stride[0], h), min(stride[1], w))
742
+ print("reducing stride")
743
+
744
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
745
+
746
+ z = unfold(z) # (bn, nc * prod(**ks), L)
747
+ # 1. Reshape to img shape
748
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
749
+
750
+ # 2. apply model loop over last dim
751
+ if isinstance(self.first_stage_model, VQModelInterface):
752
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
753
+ force_not_quantize=predict_cids or force_not_quantize)
754
+ for i in range(z.shape[-1])]
755
+ else:
756
+
757
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
758
+ for i in range(z.shape[-1])]
759
+
760
+                 o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)
761
+ o = o * weighting
762
+ # Reverse 1. reshape to img shape
763
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
764
+ # stitch crops together
765
+ decoded = fold(o)
766
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
767
+ return decoded
768
+ else:
769
+ if isinstance(self.first_stage_model, VQModelInterface):
770
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
771
+ else:
772
+ return self.first_stage_model.decode(z)
773
+
774
+ else:
775
+ if isinstance(self.first_stage_model, VQModelInterface):
776
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
777
+ else:
778
+ return self.first_stage_model.decode(z)
779
+
780
+ # same as above but without decorator
781
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
782
+ if predict_cids:
783
+ if z.dim() == 4:
784
+ z = torch.argmax(z.exp(), dim=1).long()
785
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
786
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
787
+
788
+ z = 1. / self.scale_factor * z
789
+
790
+ if hasattr(self, "split_input_params"):
791
+ if self.split_input_params["patch_distributed_vq"]:
792
+ ks = self.split_input_params["ks"] # eg. (128, 128)
793
+ stride = self.split_input_params["stride"] # eg. (64, 64)
794
+ uf = self.split_input_params["vqf"]
795
+ bs, nc, h, w = z.shape
796
+ if ks[0] > h or ks[1] > w:
797
+ ks = (min(ks[0], h), min(ks[1], w))
798
+ print("reducing Kernel")
799
+
800
+ if stride[0] > h or stride[1] > w:
801
+ stride = (min(stride[0], h), min(stride[1], w))
802
+ print("reducing stride")
803
+
804
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
805
+
806
+ z = unfold(z) # (bn, nc * prod(**ks), L)
807
+ # 1. Reshape to img shape
808
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
809
+
810
+ # 2. apply model loop over last dim
811
+ if isinstance(self.first_stage_model, VQModelInterface):
812
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
813
+ force_not_quantize=predict_cids or force_not_quantize)
814
+ for i in range(z.shape[-1])]
815
+ else:
816
+
817
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
818
+ for i in range(z.shape[-1])]
819
+
820
+                 o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)
821
+ o = o * weighting
822
+ # Reverse 1. reshape to img shape
823
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
824
+ # stitch crops together
825
+ decoded = fold(o)
826
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
827
+ return decoded
828
+ else:
829
+ if isinstance(self.first_stage_model, VQModelInterface):
830
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
831
+ else:
832
+ return self.first_stage_model.decode(z)
833
+
834
+ else:
835
+ if isinstance(self.first_stage_model, VQModelInterface):
836
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
837
+ else:
838
+ return self.first_stage_model.decode(z)
839
+
840
+ @torch.no_grad()
841
+ def encode_first_stage(self, x):
842
+ if hasattr(self, "split_input_params"):
843
+ if self.split_input_params["patch_distributed_vq"]:
844
+ ks = self.split_input_params["ks"] # eg. (128, 128)
845
+ stride = self.split_input_params["stride"] # eg. (64, 64)
846
+ df = self.split_input_params["vqf"]
847
+ self.split_input_params['original_image_size'] = x.shape[-2:]
848
+ bs, nc, h, w = x.shape
849
+ if ks[0] > h or ks[1] > w:
850
+ ks = (min(ks[0], h), min(ks[1], w))
851
+ print("reducing Kernel")
852
+
853
+ if stride[0] > h or stride[1] > w:
854
+ stride = (min(stride[0], h), min(stride[1], w))
855
+ print("reducing stride")
856
+
857
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
858
+ z = unfold(x) # (bn, nc * prod(**ks), L)
859
+ # Reshape to img shape
860
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
861
+
862
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
863
+ for i in range(z.shape[-1])]
864
+
865
+ o = torch.stack(output_list, axis=-1)
866
+ o = o * weighting
867
+
868
+ # Reverse reshape to img shape
869
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
870
+ # stitch crops together
871
+ decoded = fold(o)
872
+ decoded = decoded / normalization
873
+ return decoded
874
+
875
+ else:
876
+ return self.first_stage_model.encode(x)
877
+ else:
878
+ return self.first_stage_model.encode(x)
879
+
880
+ def shared_step(self, batch, **kwargs):
881
+ x, c = self.get_input(batch, self.first_stage_key)
882
+ loss = self(x, c)
883
+ return loss
884
+
885
+ def forward(self, x, c, *args, **kwargs):
886
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
887
+ if self.model.conditioning_key is not None:
888
+ assert c is not None
889
+ if self.cond_stage_trainable:
890
+ c = self.get_learned_conditioning(c)
891
+ if self.shorten_cond_schedule: # TODO: drop this option
892
+ tc = self.cond_ids[t].to(self.device)
893
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
894
+ return self.p_losses(x, c, t, *args, **kwargs)
895
+
896
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
897
+
898
+ if isinstance(cond, dict):
899
+             # hybrid case, cond is expected to be a dict
900
+ pass
901
+ else:
902
+ if not isinstance(cond, list):
903
+ cond = [cond]
904
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
905
+ cond = {key: cond}
906
+
907
+ if hasattr(self, "split_input_params"):
908
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
909
+ assert not return_ids
910
+ ks = self.split_input_params["ks"] # eg. (128, 128)
911
+ stride = self.split_input_params["stride"] # eg. (64, 64)
912
+
913
+ h, w = x_noisy.shape[-2:]
914
+
915
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
916
+
917
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
918
+ # Reshape to img shape
919
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
920
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
921
+
922
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
923
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
924
+ c_key = next(iter(cond.keys())) # get key
925
+ c = next(iter(cond.values())) # get value
926
+ assert (len(c) == 1) # todo extend to list with more than one elem
927
+ c = c[0] # get element
928
+
929
+ c = unfold(c)
930
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
931
+
932
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
933
+
934
+ elif self.cond_stage_key == 'coordinates_bbox':
935
+                 assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
936
+
937
+ # assuming padding of unfold is always 0 and its dilation is always 1
938
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
939
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
940
+ # as we are operating on latents, we need the factor from the original image size to the
941
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
942
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
943
+ rescale_latent = 2 ** (num_downs)
944
+
945
+                 # get top left positions of patches conforming to the bbox tokenizer; therefore we
946
+ # need to rescale the tl patch coordinates to be in between (0,1)
947
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
948
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
949
+ for patch_nr in range(z.shape[-1])]
950
+
951
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
952
+ patch_limits = [(x_tl, y_tl,
953
+ rescale_latent * ks[0] / full_img_w,
954
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
955
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
956
+
957
+ # tokenize crop coordinates for the bounding boxes of the respective patches
958
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
959
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
960
+ print(patch_limits_tknzd[0].shape)
961
+ # cut tknzd crop position from conditioning
962
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
963
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
964
+ print(cut_cond.shape)
965
+
966
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
967
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
968
+ print(adapted_cond.shape)
969
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
970
+ print(adapted_cond.shape)
971
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
972
+ print(adapted_cond.shape)
973
+
974
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
975
+
976
+ else:
977
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
978
+
979
+ # apply model by loop over crops
980
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
981
+ assert not isinstance(output_list[0],
982
+                                   tuple)  # todo: can't deal with multiple model outputs; check this never happens
983
+
984
+ o = torch.stack(output_list, axis=-1)
985
+ o = o * weighting
986
+ # Reverse reshape to img shape
987
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
988
+ # stitch crops together
989
+ x_recon = fold(o) / normalization
990
+
991
+ else:
992
+ x_recon = self.model(x_noisy, t, **cond)
993
+
994
+ if isinstance(x_recon, tuple) and not return_ids:
995
+ return x_recon[0]
996
+ else:
997
+ return x_recon
998
+
999
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1000
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1001
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1002
+
1003
+ def _prior_bpd(self, x_start):
1004
+ """
1005
+ Get the prior KL term for the variational lower-bound, measured in
1006
+ bits-per-dim.
1007
+ This term can't be optimized, as it only depends on the encoder.
1008
+ :param x_start: the [N x C x ...] tensor of inputs.
1009
+ :return: a batch of [N] KL values (in bits), one per batch element.
1010
+ """
1011
+ batch_size = x_start.shape[0]
1012
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1013
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1014
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1015
+ return mean_flat(kl_prior) / np.log(2.0)
1016
+
1017
+ def p_losses(self, x_start, cond, t, noise=None):
1018
+ noise = default(noise, lambda: torch.randn_like(x_start))
1019
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1020
+ model_output = self.apply_model(x_noisy, t, cond)
1021
+
1022
+ loss_dict = {}
1023
+ prefix = 'train' if self.training else 'val'
1024
+
1025
+ if self.parameterization == "x0":
1026
+ target = x_start
1027
+ elif self.parameterization == "eps":
1028
+ target = noise
1029
+ else:
1030
+ raise NotImplementedError()
1031
+
1032
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1033
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1034
+
1035
+ logvar_t = self.logvar[t].to(self.device)
1036
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1037
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1038
+ if self.learn_logvar:
1039
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1040
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1041
+
1042
+ loss = self.l_simple_weight * loss.mean()
1043
+
1044
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1045
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1046
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1047
+ loss += (self.original_elbo_weight * loss_vlb)
1048
+ loss_dict.update({f'{prefix}/loss': loss})
1049
+
1050
+ return loss, loss_dict
1051
+
1052
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1053
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1054
+ t_in = t
1055
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1056
+
1057
+ if score_corrector is not None:
1058
+ assert self.parameterization == "eps"
1059
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1060
+
1061
+ if return_codebook_ids:
1062
+ model_out, logits = model_out
1063
+
1064
+ if self.parameterization == "eps":
1065
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1066
+ elif self.parameterization == "x0":
1067
+ x_recon = model_out
1068
+ else:
1069
+ raise NotImplementedError()
1070
+
1071
+ if clip_denoised:
1072
+ x_recon.clamp_(-1., 1.)
1073
+ if quantize_denoised:
1074
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1075
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1076
+ if return_codebook_ids:
1077
+ return model_mean, posterior_variance, posterior_log_variance, logits
1078
+ elif return_x0:
1079
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1080
+ else:
1081
+ return model_mean, posterior_variance, posterior_log_variance
1082
+
1083
+ @torch.no_grad()
1084
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1085
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1086
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1087
+ b, *_, device = *x.shape, x.device
1088
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1089
+ return_codebook_ids=return_codebook_ids,
1090
+ quantize_denoised=quantize_denoised,
1091
+ return_x0=return_x0,
1092
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1093
+ if return_codebook_ids:
1094
+ raise DeprecationWarning("Support dropped.")
1095
+ model_mean, _, model_log_variance, logits = outputs
1096
+ elif return_x0:
1097
+ model_mean, _, model_log_variance, x0 = outputs
1098
+ else:
1099
+ model_mean, _, model_log_variance = outputs
1100
+
1101
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1102
+ if noise_dropout > 0.:
1103
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1104
+ # no noise when t == 0
1105
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1106
+
1107
+ if return_codebook_ids:
1108
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1109
+ if return_x0:
1110
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1111
+ else:
1112
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1113
+
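A tiny editor sketch (not part of this file) of the `nonzero_mask` used in p_sample above: samples that are at their final timestep (t == 0) receive no extra noise, all others do. The shapes and timestep values are illustrative.

import torch

b = 4
x_shape = (b, 3, 8, 8)                             # illustrative latent shape
t = torch.tensor([0, 5, 0, 9])                     # per-sample timesteps
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_shape) - 1)))
print(nonzero_mask.flatten().tolist())             # [0.0, 1.0, 0.0, 1.0]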
1114
+ @torch.no_grad()
1115
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1116
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1117
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1118
+ log_every_t=None):
1119
+ if not log_every_t:
1120
+ log_every_t = self.log_every_t
1121
+ timesteps = self.num_timesteps
1122
+ if batch_size is not None:
1123
+ b = batch_size if batch_size is not None else shape[0]
1124
+ shape = [batch_size] + list(shape)
1125
+ else:
1126
+ b = batch_size = shape[0]
1127
+ if x_T is None:
1128
+ img = torch.randn(shape, device=self.device)
1129
+ else:
1130
+ img = x_T
1131
+ intermediates = []
1132
+ if cond is not None:
1133
+ if isinstance(cond, dict):
1134
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1135
+ [x[:batch_size] for x in cond[key]] for key in cond}
1136
+ else:
1137
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1138
+
1139
+ if start_T is not None:
1140
+ timesteps = min(timesteps, start_T)
1141
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1142
+ total=timesteps) if verbose else reversed(
1143
+ range(0, timesteps))
1144
+ if type(temperature) == float:
1145
+ temperature = [temperature] * timesteps
1146
+
1147
+ for i in iterator:
1148
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1149
+ if self.shorten_cond_schedule:
1150
+ assert self.model.conditioning_key != 'hybrid'
1151
+ tc = self.cond_ids[ts].to(cond.device)
1152
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1153
+
1154
+ img, x0_partial = self.p_sample(img, cond, ts,
1155
+ clip_denoised=self.clip_denoised,
1156
+ quantize_denoised=quantize_denoised, return_x0=True,
1157
+ temperature=temperature[i], noise_dropout=noise_dropout,
1158
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1159
+ if mask is not None:
1160
+ assert x0 is not None
1161
+ img_orig = self.q_sample(x0, ts)
1162
+ img = img_orig * mask + (1. - mask) * img
1163
+
1164
+ if i % log_every_t == 0 or i == timesteps - 1:
1165
+ intermediates.append(x0_partial)
1166
+ if callback:
1167
+ callback(i)
1168
+ if img_callback:
1169
+ img_callback(img, i)
1170
+ return img, intermediates
1171
+
1172
+ @torch.no_grad()
1173
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1174
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1175
+ mask=None, x0=None, img_callback=None, start_T=None,
1176
+ log_every_t=None):
1177
+
1178
+ if not log_every_t:
1179
+ log_every_t = self.log_every_t
1180
+ device = self.betas.device
1181
+ b = shape[0]
1182
+ if x_T is None:
1183
+ img = torch.randn(shape, device=device)
1184
+ else:
1185
+ img = x_T
1186
+
1187
+ intermediates = [img]
1188
+ if timesteps is None:
1189
+ timesteps = self.num_timesteps
1190
+
1191
+ if start_T is not None:
1192
+ timesteps = min(timesteps, start_T)
1193
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1194
+ range(0, timesteps))
1195
+
1196
+ if mask is not None:
1197
+ assert x0 is not None
1198
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1199
+
1200
+ for i in iterator:
1201
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1202
+ if self.shorten_cond_schedule:
1203
+ assert self.model.conditioning_key != 'hybrid'
1204
+ tc = self.cond_ids[ts].to(cond.device)
1205
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1206
+
1207
+ img = self.p_sample(img, cond, ts,
1208
+ clip_denoised=self.clip_denoised,
1209
+ quantize_denoised=quantize_denoised)
1210
+ if mask is not None:
1211
+ img_orig = self.q_sample(x0, ts)
1212
+ img = img_orig * mask + (1. - mask) * img
1213
+
1214
+ if i % log_every_t == 0 or i == timesteps - 1:
1215
+ intermediates.append(img)
1216
+ if callback:
1217
+ callback(i)
1218
+ if img_callback:
1219
+ img_callback(img, i)
1220
+
1221
+ if return_intermediates:
1222
+ return img, intermediates
1223
+ return img
1224
+
1225
+ @torch.no_grad()
1226
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1227
+ verbose=True, timesteps=None, quantize_denoised=False,
1228
+ mask=None, x0=None, shape=None,**kwargs):
1229
+ if shape is None:
1230
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1231
+ if cond is not None:
1232
+ if isinstance(cond, dict):
1233
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1234
+ [x[:batch_size] for x in cond[key]] for key in cond}
1235
+ else:
1236
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1237
+ return self.p_sample_loop(cond,
1238
+ shape,
1239
+ return_intermediates=return_intermediates, x_T=x_T,
1240
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1241
+ mask=mask, x0=x0)
1242
+
1243
+ @torch.no_grad()
1244
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1245
+
1246
+ if ddim:
1247
+ ddim_sampler = DDIMSampler(self)
1248
+ shape = (self.channels, self.image_size, self.image_size)
1249
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1250
+ shape,cond,verbose=False,**kwargs)
1251
+
1252
+ else:
1253
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1254
+ return_intermediates=True,**kwargs)
1255
+
1256
+ return samples, intermediates
1257
+
1258
+
1259
+ @torch.no_grad()
1260
+ def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1261
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
1262
+ plot_diffusion_rows=False, **kwargs):
1263
+
1264
+ use_ddim = False
1265
+
1266
+ log = {}
1267
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1268
+ return_first_stage_outputs=True,
1269
+ force_c_encode=True,
1270
+ return_original_cond=True,
1271
+ bs=N, uncond=0)
1272
+ N = min(x.shape[0], N)
1273
+ n_row = min(x.shape[0], n_row)
1274
+ log["inputs"] = x
1275
+ log["reals"] = xc["c_concat"]
1276
+ log["reconstruction"] = xrec
1277
+ if self.model.conditioning_key is not None:
1278
+ if hasattr(self.cond_stage_model, "decode"):
1279
+ xc = self.cond_stage_model.decode(c)
1280
+ log["conditioning"] = xc
1281
+ elif self.cond_stage_key in ["caption"]:
1282
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1283
+ log["conditioning"] = xc
1284
+ elif self.cond_stage_key == 'class_label':
1285
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1286
+ log['conditioning'] = xc
1287
+ elif isimage(xc):
1288
+ log["conditioning"] = xc
1289
+ if ismap(xc):
1290
+ log["original_conditioning"] = self.to_rgb(xc)
1291
+
1292
+ if plot_diffusion_rows:
1293
+ # get diffusion row
1294
+ diffusion_row = []
1295
+ z_start = z[:n_row]
1296
+ for t in range(self.num_timesteps):
1297
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1298
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1299
+ t = t.to(self.device).long()
1300
+ noise = torch.randn_like(z_start)
1301
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1302
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1303
+
1304
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1305
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1306
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1307
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1308
+ log["diffusion_row"] = diffusion_grid
1309
+
1310
+ if sample:
1311
+ # get denoise row
1312
+ with self.ema_scope("Plotting"):
1313
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1314
+ ddim_steps=ddim_steps,eta=ddim_eta)
1315
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1316
+ x_samples = self.decode_first_stage(samples)
1317
+ log["samples"] = x_samples
1318
+ if plot_denoise_rows:
1319
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1320
+ log["denoise_row"] = denoise_grid
1321
+
1322
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1323
+ self.first_stage_model, IdentityFirstStage):
1324
+ # also display when quantizing x0 while sampling
1325
+ with self.ema_scope("Plotting Quantized Denoised"):
1326
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1327
+ ddim_steps=ddim_steps,eta=ddim_eta,
1328
+ quantize_denoised=True)
1329
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1330
+ # quantize_denoised=True)
1331
+ x_samples = self.decode_first_stage(samples.to(self.device))
1332
+ log["samples_x0_quantized"] = x_samples
1333
+
1334
+ if inpaint:
1335
+ # make a simple center square
1336
+ h, w = z.shape[2], z.shape[3]
1337
+ mask = torch.ones(N, h, w).to(self.device)
1338
+ # zeros will be filled in
1339
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1340
+ mask = mask[:, None, ...]
1341
+ with self.ema_scope("Plotting Inpaint"):
1342
+
1343
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1344
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1345
+ x_samples = self.decode_first_stage(samples.to(self.device))
1346
+ log["samples_inpainting"] = x_samples
1347
+ log["mask"] = mask
1348
+
1349
+ # outpaint
1350
+ with self.ema_scope("Plotting Outpaint"):
1351
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1352
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1353
+ x_samples = self.decode_first_stage(samples.to(self.device))
1354
+ log["samples_outpainting"] = x_samples
1355
+
1356
+ if plot_progressive_rows:
1357
+ with self.ema_scope("Plotting Progressives"):
1358
+ img, progressives = self.progressive_denoising(c,
1359
+ shape=(self.channels, self.image_size, self.image_size),
1360
+ batch_size=N)
1361
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1362
+ log["progressive_row"] = prog_row
1363
+
1364
+ if return_keys:
1365
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1366
+ return log
1367
+ else:
1368
+ return {key: log[key] for key in return_keys}
1369
+ return log
1370
+
1371
+ def configure_optimizers(self):
1372
+ lr = self.learning_rate
1373
+ params = list(self.model.parameters())
1374
+ if self.cond_stage_trainable:
1375
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1376
+ params = params + list(self.cond_stage_model.parameters())
1377
+ if self.learn_logvar:
1378
+ print('Diffusion model optimizing logvar')
1379
+ params.append(self.logvar)
1380
+ opt = torch.optim.AdamW(params, lr=lr)
1381
+ if self.use_scheduler:
1382
+ assert 'target' in self.scheduler_config
1383
+ scheduler = instantiate_from_config(self.scheduler_config)
1384
+
1385
+ print("Setting up LambdaLR scheduler...")
1386
+ scheduler = [
1387
+ {
1388
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1389
+ 'interval': 'step',
1390
+ 'frequency': 1
1391
+ }]
1392
+ return [opt], scheduler
1393
+ return opt
1394
+
1395
+ @torch.no_grad()
1396
+ def to_rgb(self, x):
1397
+ x = x.float()
1398
+ if not hasattr(self, "colorize"):
1399
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1400
+ x = nn.functional.conv2d(x, weight=self.colorize)
1401
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1402
+ return x
1403
+
1404
+
1405
+ class DiffusionWrapper(pl.LightningModule):
1406
+ def __init__(self, diff_model_config, conditioning_key):
1407
+ super().__init__()
1408
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1409
+ self.conditioning_key = conditioning_key
1410
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1411
+
1412
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
1413
+ if self.conditioning_key is None:
1414
+ out = self.diffusion_model(x, t)
1415
+ elif self.conditioning_key == 'concat':
1416
+ xc = torch.cat([x] + c_concat, dim=1)
1417
+ out = self.diffusion_model(xc, t)
1418
+ elif self.conditioning_key == 'crossattn':
1419
+ cc = torch.cat(c_crossattn, 1)
1420
+ out = self.diffusion_model(x, t, context=cc)
1421
+ elif self.conditioning_key == 'hybrid':
1422
+ xc = torch.cat([x] + c_concat, dim=1)
1423
+ cc = torch.cat(c_crossattn, 1)
1424
+ out = self.diffusion_model(xc, t, context=cc)
1425
+ elif self.conditioning_key == 'adm':
1426
+ cc = c_crossattn[0]
1427
+ out = self.diffusion_model(x, t, y=cc)
1428
+ else:
1429
+ raise NotImplementedError()
1430
+
1431
+ return out
1432
+
1433
+
1434
+ class Layout2ImgDiffusion(LatentDiffusion):
1435
+ # TODO: move all layout-specific hacks to this class
1436
+ def __init__(self, cond_stage_key, *args, **kwargs):
1437
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1438
+ super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
1439
+
1440
+ def log_images(self, batch, N=8, *args, **kwargs):
1441
+ logs = super().log_images(*args, batch=batch, N=N, **kwargs)
1442
+
1443
+ key = 'train' if self.training else 'validation'
1444
+ dset = self.trainer.datamodule.datasets[key]
1445
+ mapper = dset.conditional_builders[self.cond_stage_key]
1446
+
1447
+ bbox_imgs = []
1448
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1449
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1450
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1451
+ bbox_imgs.append(bboximg)
1452
+
1453
+ cond_img = torch.stack(bbox_imgs, dim=0)
1454
+ logs['bbox_image'] = cond_img
1455
+ return logs
modules/models/diffusion/uni_pc/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampler import UniPCSampler # noqa: F401
modules/models/diffusion/uni_pc/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (204 Bytes).
 
modules/models/diffusion/uni_pc/__pycache__/sampler.cpython-310.pyc ADDED
Binary file (3.21 kB).
 
modules/models/diffusion/uni_pc/__pycache__/uni_pc.cpython-310.pyc ADDED
Binary file (27.6 kB).
 
modules/models/diffusion/uni_pc/sampler.py ADDED
@@ -0,0 +1,101 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
6
+ from modules import shared, devices
7
+
8
+
9
+ class UniPCSampler(object):
10
+ def __init__(self, model, **kwargs):
11
+ super().__init__()
12
+ self.model = model
13
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
14
+ self.before_sample = None
15
+ self.after_sample = None
16
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != devices.device:
21
+ attr = attr.to(devices.device)
22
+ setattr(self, name, attr)
23
+
24
+ def set_hooks(self, before_sample, after_sample, after_update):
25
+ self.before_sample = before_sample
26
+ self.after_sample = after_sample
27
+ self.after_update = after_update
28
+
29
+ @torch.no_grad()
30
+ def sample(self,
31
+ S,
32
+ batch_size,
33
+ shape,
34
+ conditioning=None,
35
+ callback=None,
36
+ normals_sequence=None,
37
+ img_callback=None,
38
+ quantize_x0=False,
39
+ eta=0.,
40
+ mask=None,
41
+ x0=None,
42
+ temperature=1.,
43
+ noise_dropout=0.,
44
+ score_corrector=None,
45
+ corrector_kwargs=None,
46
+ verbose=True,
47
+ x_T=None,
48
+ log_every_t=100,
49
+ unconditional_guidance_scale=1.,
50
+ unconditional_conditioning=None,
51
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
52
+ **kwargs
53
+ ):
54
+ if conditioning is not None:
55
+ if isinstance(conditioning, dict):
56
+ ctmp = conditioning[list(conditioning.keys())[0]]
57
+ while isinstance(ctmp, list):
58
+ ctmp = ctmp[0]
59
+ cbs = ctmp.shape[0]
60
+ if cbs != batch_size:
61
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
62
+
63
+ elif isinstance(conditioning, list):
64
+ for ctmp in conditioning:
65
+ if ctmp.shape[0] != batch_size:
66
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
67
+
68
+ else:
69
+ if conditioning.shape[0] != batch_size:
70
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
71
+
72
+ # sampling
73
+ C, H, W = shape
74
+ size = (batch_size, C, H, W)
75
+ # print(f'Data shape for UniPC sampling is {size}')
76
+
77
+ device = self.model.betas.device
78
+ if x_T is None:
79
+ img = torch.randn(size, device=device)
80
+ else:
81
+ img = x_T
82
+
83
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
84
+
85
+ # SD 1.X is "noise", SD 2.X is "v"
86
+ model_type = "v" if self.model.parameterization == "v" else "noise"
87
+
88
+ model_fn = model_wrapper(
89
+ lambda x, t, c: self.model.apply_model(x, t, c),
90
+ ns,
91
+ model_type=model_type,
92
+ guidance_type="classifier-free",
93
+ #condition=conditioning,
94
+ #unconditional_condition=unconditional_conditioning,
95
+ guidance_scale=unconditional_guidance_scale,
96
+ )
97
+
98
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
99
+ x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
100
+
101
+ return x.to(device), None
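A hedged usage sketch (editor addition, not part of the commit) of how UniPCSampler is typically driven. It assumes `sd_model` is an already-loaded LatentDiffusion model exposing betas, alphas_cumprod, parameterization and apply_model, and that the webui's `shared.opts` UniPC options are available; the latent shape, step count and guidance scale below are illustrative.

from modules.models.diffusion.uni_pc.sampler import UniPCSampler

def sample_with_unipc(sd_model, cond, uncond, steps=20, batch_size=1, cfg_scale=7.0):
    sampler = UniPCSampler(sd_model)
    sampler.set_hooks(None, None, None)            # no callbacks; ensures the optional hook attributes exist
    latent_shape = (4, 64, 64)                     # (C, H, W) latent size for a 512x512 image (assumption)
    x, _ = sampler.sample(
        steps,
        batch_size,
        latent_shape,
        conditioning=cond,
        unconditional_conditioning=uncond,
        unconditional_guidance_scale=cfg_scale,
    )
    return x                                       # denoised latents; decode with the first-stage model afterwards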
modules/models/diffusion/uni_pc/uni_pc.py ADDED
@@ -0,0 +1,863 @@
1
+ import torch
2
+ import math
3
+ import tqdm
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule='discrete',
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.,
14
+ ):
15
+ """Create a wrapper class for the forward SDE (VP type).
16
+
17
+ ***
18
+         Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
19
+         We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+
22
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+
30
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
31
+
32
+ t = self.inverse_lambda(lambda_t)
33
+
34
+ ===============================================================
35
+
36
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
37
+
38
+ 1. For discrete-time DPMs:
39
+
40
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
41
+ t_i = (i + 1) / N
42
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
43
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
44
+
45
+ Args:
46
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
47
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
48
+
49
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
50
+
51
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
52
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
53
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
54
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
55
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
56
+ and
57
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
58
+
59
+
60
+ 2. For continuous-time DPMs:
61
+
62
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
63
+ schedule are the default settings in DDPM and improved-DDPM:
64
+
65
+ Args:
66
+ beta_min: A `float` number. The smallest beta for the linear schedule.
67
+ beta_max: A `float` number. The largest beta for the linear schedule.
68
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
69
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
70
+ T: A `float` number. The ending time of the forward process.
71
+
72
+ ===============================================================
73
+
74
+ Args:
75
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
76
+ 'linear' or 'cosine' for continuous-time DPMs.
77
+ Returns:
78
+ A wrapper object of the forward SDE (VP type).
79
+
80
+ ===============================================================
81
+
82
+ Example:
83
+
84
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
85
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
86
+
87
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
88
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
89
+
90
+ # For continuous-time DPMs (VPSDE), linear schedule:
91
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
92
+
93
+ """
94
+
95
+ if schedule not in ['discrete', 'linear', 'cosine']:
96
+ raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
97
+
98
+ self.schedule = schedule
99
+ if schedule == 'discrete':
100
+ if betas is not None:
101
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
102
+ else:
103
+ assert alphas_cumprod is not None
104
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
105
+ self.total_N = len(log_alphas)
106
+ self.T = 1.
107
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
108
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
109
+ else:
110
+ self.total_N = 1000
111
+ self.beta_0 = continuous_beta_0
112
+ self.beta_1 = continuous_beta_1
113
+ self.cosine_s = 0.008
114
+ self.cosine_beta_max = 999.
115
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
116
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
117
+ self.schedule = schedule
118
+ if schedule == 'cosine':
119
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
120
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
121
+ self.T = 0.9946
122
+ else:
123
+ self.T = 1.
124
+
125
+ def marginal_log_mean_coeff(self, t):
126
+ """
127
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
128
+ """
129
+ if self.schedule == 'discrete':
130
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
131
+ elif self.schedule == 'linear':
132
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
133
+ elif self.schedule == 'cosine':
134
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
135
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
136
+ return log_alpha_t
137
+
138
+ def marginal_alpha(self, t):
139
+ """
140
+ Compute alpha_t of a given continuous-time label t in [0, T].
141
+ """
142
+ return torch.exp(self.marginal_log_mean_coeff(t))
143
+
144
+ def marginal_std(self, t):
145
+ """
146
+ Compute sigma_t of a given continuous-time label t in [0, T].
147
+ """
148
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
149
+
150
+ def marginal_lambda(self, t):
151
+ """
152
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
153
+ """
154
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
155
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
156
+ return log_mean_coeff - log_std
157
+
158
+ def inverse_lambda(self, lamb):
159
+ """
160
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
161
+ """
162
+ if self.schedule == 'linear':
163
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
164
+ Delta = self.beta_0**2 + tmp
165
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
166
+ elif self.schedule == 'discrete':
167
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
168
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
169
+ return t.reshape((-1,))
170
+ else:
171
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
172
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
173
+ t = t_fn(log_alpha)
174
+ return t
175
+
176
+
177
+ def model_wrapper(
178
+ model,
179
+ noise_schedule,
180
+ model_type="noise",
181
+ model_kwargs=None,
182
+ guidance_type="uncond",
183
+ #condition=None,
184
+ #unconditional_condition=None,
185
+ guidance_scale=1.,
186
+ classifier_fn=None,
187
+ classifier_kwargs=None,
188
+ ):
189
+ """Create a wrapper function for the noise prediction model.
190
+
191
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
192
+     first wrap the model function into a noise prediction model that accepts the continuous time as the input.
193
+
194
+ We support four types of the diffusion model by setting `model_type`:
195
+
196
+ 1. "noise": noise prediction model. (Trained by predicting noise).
197
+
198
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
199
+
200
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
201
+             The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
202
+
203
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
204
+ arXiv preprint arXiv:2202.00512 (2022).
205
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
206
+ arXiv preprint arXiv:2210.02303 (2022).
207
+
208
+ 4. "score": marginal score function. (Trained by denoising score matching).
209
+ Note that the score function and the noise prediction model follows a simple relationship:
210
+ ```
211
+ noise(x_t, t) = -sigma_t * score(x_t, t)
212
+ ```
213
+
214
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
215
+ 1. "uncond": unconditional sampling by DPMs.
216
+ The input `model` has the following format:
217
+ ``
218
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
219
+ ``
220
+
221
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
222
+ The input `model` has the following format:
223
+ ``
224
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
225
+ ``
226
+
227
+ The input `classifier_fn` has the following format:
228
+ ``
229
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
230
+ ``
231
+
232
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
233
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
234
+
235
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
236
+ The input `model` has the following format:
237
+ ``
238
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
239
+ ``
240
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
241
+
242
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
243
+ arXiv preprint arXiv:2207.12598 (2022).
244
+
245
+
246
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
247
+ or continuous-time labels (i.e. epsilon to T).
248
+
249
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
250
+ ``
251
+ def model_fn(x, t_continuous) -> noise:
252
+ t_input = get_model_input_time(t_continuous)
253
+ return noise_pred(model, x, t_input, **model_kwargs)
254
+ ``
255
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
256
+
257
+ ===============================================================
258
+
259
+ Args:
260
+ model: A diffusion model with the corresponding format described above.
261
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
262
+ model_type: A `str`. The parameterization type of the diffusion model.
263
+ "noise" or "x_start" or "v" or "score".
264
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
265
+ guidance_type: A `str`. The type of the guidance for sampling.
266
+ "uncond" or "classifier" or "classifier-free".
267
+ condition: A pytorch tensor. The condition for the guided sampling.
268
+ Only used for "classifier" or "classifier-free" guidance type.
269
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
270
+ Only used for "classifier-free" guidance type.
271
+ guidance_scale: A `float`. The scale for the guided sampling.
272
+ classifier_fn: A classifier function. Only used for the classifier guidance.
273
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
274
+ Returns:
275
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
276
+ """
277
+
278
+ model_kwargs = model_kwargs or {}
279
+ classifier_kwargs = classifier_kwargs or {}
280
+
281
+ def get_model_input_time(t_continuous):
282
+ """
283
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
284
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
285
+ For continuous-time DPMs, we just use `t_continuous`.
286
+ """
287
+ if noise_schedule.schedule == 'discrete':
288
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
289
+ else:
290
+ return t_continuous
291
+
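As a quick sanity check on the conversion above (not part of the module): for a discrete-time DPM with `total_N = 1000`, the first and last continuous times map to model input times 0 and 999.

```python
import math

total_N = 1000
t_first = (1.0 / total_N - 1.0 / total_N) * 1000.0   # t_continuous = 1/N  ->  0
t_last = (1.0 - 1.0 / total_N) * 1000.0              # t_continuous = 1    ->  999
assert t_first == 0.0 and math.isclose(t_last, 999.0)
```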
292
+ def noise_pred_fn(x, t_continuous, cond=None):
293
+ if t_continuous.reshape((-1,)).shape[0] == 1:
294
+ t_continuous = t_continuous.expand((x.shape[0]))
295
+ t_input = get_model_input_time(t_continuous)
296
+ if cond is None:
297
+ output = model(x, t_input, None, **model_kwargs)
298
+ else:
299
+ output = model(x, t_input, cond, **model_kwargs)
300
+ if model_type == "noise":
301
+ return output
302
+ elif model_type == "x_start":
303
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
304
+ dims = x.dim()
305
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
306
+ elif model_type == "v":
307
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
308
+ dims = x.dim()
309
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
310
+ elif model_type == "score":
311
+ sigma_t = noise_schedule.marginal_std(t_continuous)
312
+ dims = x.dim()
313
+ return -expand_dims(sigma_t, dims) * output
314
+
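The `x_start` and `v` branches above are algebraic rearrangements of the forward process x_t = alpha_t * x0 + sigma_t * eps. A small numerical check, under the assumption alpha_t^2 + sigma_t^2 = 1 (as in VP schedules):

```python
import torch

alpha_t, sigma_t = 0.8, 0.6                      # alpha_t**2 + sigma_t**2 == 1
x0, eps = torch.randn(2, 4), torch.randn(2, 4)
x_t = alpha_t * x0 + sigma_t * eps

# "x_start" parameterization: recover eps from a predicted x0
assert torch.allclose((x_t - alpha_t * x0) / sigma_t, eps, atol=1e-5)

# "v" parameterization: v = alpha_t * eps - sigma_t * x0, so alpha_t * v + sigma_t * x_t == eps
v = alpha_t * eps - sigma_t * x0
assert torch.allclose(alpha_t * v + sigma_t * x_t, eps, atol=1e-5)
```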
315
+ def cond_grad_fn(x, t_input, condition):
316
+ """
317
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
318
+ """
319
+ with torch.enable_grad():
320
+ x_in = x.detach().requires_grad_(True)
321
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
322
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
323
+
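The same autograd pattern in isolation, with a toy stand-in for the classifier (nothing here is part of the webui API):

```python
import torch

def toy_log_prob(x, cond):
    # stands in for log p(cond | x_t); any scalar-per-sample function works
    return -((x - cond) ** 2).sum(dim=(1, 2, 3))

x = torch.randn(2, 3, 8, 8)
cond = torch.zeros_like(x)
x_in = x.detach().requires_grad_(True)
grad = torch.autograd.grad(toy_log_prob(x_in, cond).sum(), x_in)[0]
assert grad.shape == x.shape   # one gradient entry per input element, i.e. nabla_x log p(cond | x)
```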
324
+ def model_fn(x, t_continuous, condition, unconditional_condition):
325
+ """
326
+ The noise prediction model function that is used for DPM-Solver.
327
+ """
328
+ if t_continuous.reshape((-1,)).shape[0] == 1:
329
+ t_continuous = t_continuous.expand((x.shape[0]))
330
+ if guidance_type == "uncond":
331
+ return noise_pred_fn(x, t_continuous)
332
+ elif guidance_type == "classifier":
333
+ assert classifier_fn is not None
334
+ t_input = get_model_input_time(t_continuous)
335
+ cond_grad = cond_grad_fn(x, t_input, condition)
336
+ sigma_t = noise_schedule.marginal_std(t_continuous)
337
+ noise = noise_pred_fn(x, t_continuous)
338
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
339
+ elif guidance_type == "classifier-free":
340
+ if guidance_scale == 1. or unconditional_condition is None:
341
+ return noise_pred_fn(x, t_continuous, cond=condition)
342
+ else:
343
+ x_in = torch.cat([x] * 2)
344
+ t_in = torch.cat([t_continuous] * 2)
345
+ if isinstance(condition, dict):
346
+ assert isinstance(unconditional_condition, dict)
347
+ c_in = {}
348
+ for k in condition:
349
+ if isinstance(condition[k], list):
350
+ c_in[k] = [torch.cat([
351
+ unconditional_condition[k][i],
352
+ condition[k][i]]) for i in range(len(condition[k]))]
353
+ else:
354
+ c_in[k] = torch.cat([
355
+ unconditional_condition[k],
356
+ condition[k]])
357
+ elif isinstance(condition, list):
358
+ c_in = []
359
+ assert isinstance(unconditional_condition, list)
360
+ for i in range(len(condition)):
361
+ c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
362
+ else:
363
+ c_in = torch.cat([unconditional_condition, condition])
364
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
365
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
366
+
367
+ assert model_type in ["noise", "x_start", "v", "score"]
368
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
369
+ return model_fn
370
+
371
+
372
+ class UniPC:
373
+ def __init__(
374
+ self,
375
+ model_fn,
376
+ noise_schedule,
377
+ predict_x0=True,
378
+ thresholding=False,
379
+ max_val=1.,
380
+ variant='bh1',
381
+ condition=None,
382
+ unconditional_condition=None,
383
+ before_sample=None,
384
+ after_sample=None,
385
+ after_update=None
386
+ ):
387
+ """Construct a UniPC.
388
+
389
+ We support both data_prediction and noise_prediction.
390
+ """
391
+ self.model_fn_ = model_fn
392
+ self.noise_schedule = noise_schedule
393
+ self.variant = variant
394
+ self.predict_x0 = predict_x0
395
+ self.thresholding = thresholding
396
+ self.max_val = max_val
397
+ self.condition = condition
398
+ self.unconditional_condition = unconditional_condition
399
+ self.before_sample = before_sample
400
+ self.after_sample = after_sample
401
+ self.after_update = after_update
402
+
403
+ def dynamic_thresholding_fn(self, x0, t=None):
404
+ """
405
+ The dynamic thresholding method.
406
+ """
407
+ dims = x0.dim()
408
+ p = self.dynamic_thresholding_ratio
409
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
410
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
411
+ x0 = torch.clamp(x0, -s, s) / s
412
+ return x0
413
+
414
+ def model(self, x, t):
415
+ cond = self.condition
416
+ uncond = self.unconditional_condition
417
+ if self.before_sample is not None:
418
+ x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
419
+ res = self.model_fn_(x, t, cond, uncond)
420
+ if self.after_sample is not None:
421
+ x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
422
+
423
+ if isinstance(res, tuple):
424
+ # (None, pred_x0)
425
+ res = res[1]
426
+
427
+ return res
428
+
429
+ def noise_prediction_fn(self, x, t):
430
+ """
431
+ Return the model's noise prediction at (x, t).
432
+ """
433
+ return self.model(x, t)
434
+
435
+ def data_prediction_fn(self, x, t):
436
+ """
437
+ Return the model's data (x0) prediction at (x, t), with optional thresholding.
438
+ """
439
+ noise = self.noise_prediction_fn(x, t)
440
+ dims = x.dim()
441
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
442
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
443
+ if self.thresholding:
444
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
445
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
446
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
447
+ x0 = torch.clamp(x0, -s, s) / s
448
+ return x0
449
+
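The thresholding block above is Imagen-style dynamic thresholding; pulled out on its own it looks like this (a sketch assuming `max_val = 1.0`):

```python
import torch

def dynamic_threshold(x0, p=0.995, max_val=1.0):
    # per-sample p-quantile of |x0|, floored at max_val, then clamp and rescale
    s = torch.quantile(torch.abs(x0).reshape(x0.shape[0], -1), p, dim=1)
    s = torch.maximum(s, max_val * torch.ones_like(s)).reshape(-1, 1, 1, 1)
    return torch.clamp(x0, -s, s) / s

x0 = 3.0 * torch.randn(2, 4, 8, 8)
assert dynamic_threshold(x0).abs().max() <= 1.0
```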
450
+ def model_fn(self, x, t):
451
+ """
452
+ Convert the model to the noise prediction model or the data prediction model.
453
+ """
454
+ if self.predict_x0:
455
+ return self.data_prediction_fn(x, t)
456
+ else:
457
+ return self.noise_prediction_fn(x, t)
458
+
459
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
460
+ """Compute the intermediate time steps for sampling.
461
+ """
462
+ if skip_type == 'logSNR':
463
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
464
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
465
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
466
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
467
+ elif skip_type == 'time_uniform':
468
+ return torch.linspace(t_T, t_0, N + 1).to(device)
469
+ elif skip_type == 'time_quadratic':
470
+ t_order = 2
471
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
472
+ return t
473
+ else:
474
+ raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
475
+
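A quick look at the two simpler schedules above for a small N (illustrative only; the logSNR schedule additionally needs the noise schedule's lambda functions):

```python
import torch

t_T, t_0, N = 1.0, 1e-3, 5
time_uniform = torch.linspace(t_T, t_0, N + 1)                         # evenly spaced in t
time_quadratic = torch.linspace(t_T ** 0.5, t_0 ** 0.5, N + 1).pow(2)  # denser near t_0
print(time_uniform)
print(time_quadratic)
```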
476
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
477
+ """
478
+ Get the order of each step for sampling by the singlestep DPM-Solver.
479
+ """
480
+ if order == 3:
481
+ K = steps // 3 + 1
482
+ if steps % 3 == 0:
483
+ orders = [3,] * (K - 2) + [2, 1]
484
+ elif steps % 3 == 1:
485
+ orders = [3,] * (K - 1) + [1]
486
+ else:
487
+ orders = [3,] * (K - 1) + [2]
488
+ elif order == 2:
489
+ if steps % 2 == 0:
490
+ K = steps // 2
491
+ orders = [2,] * K
492
+ else:
493
+ K = steps // 2 + 1
494
+ orders = [2,] * (K - 1) + [1]
495
+ elif order == 1:
496
+ K = steps
497
+ orders = [1,] * steps
498
+ else:
499
+ raise ValueError("'order' must be '1' or '2' or '3'.")
500
+ if skip_type == 'logSNR':
501
+ # To reproduce the results in DPM-Solver paper
502
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
503
+ else:
504
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
505
+ return timesteps_outer, orders
506
+
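A worked instance of the order-3 branch above (this helper targets the singlestep solver; the multistep sampler in `sample()` does not call it):

```python
steps, order = 10, 3
K = steps // 3 + 1                 # K = 4
orders = [3] * (K - 1) + [1]       # steps % 3 == 1  ->  [3, 3, 3, 1]
assert sum(orders) == steps
```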
507
+ def denoise_to_zero_fn(self, x, s):
508
+ """
509
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
510
+ """
511
+ return self.data_prediction_fn(x, s)
512
+
513
+ def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
514
+ if len(t.shape) == 0:
515
+ t = t.view(-1)
516
+ if 'bh' in self.variant:
517
+ return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
518
+ else:
519
+ assert self.variant == 'vary_coeff'
520
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
521
+
522
+ def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
523
+ #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
524
+ ns = self.noise_schedule
525
+ assert order <= len(model_prev_list)
526
+
527
+ # first compute rks
528
+ t_prev_0 = t_prev_list[-1]
529
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
530
+ lambda_t = ns.marginal_lambda(t)
531
+ model_prev_0 = model_prev_list[-1]
532
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
533
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
534
+ alpha_t = torch.exp(log_alpha_t)
535
+
536
+ h = lambda_t - lambda_prev_0
537
+
538
+ rks = []
539
+ D1s = []
540
+ for i in range(1, order):
541
+ t_prev_i = t_prev_list[-(i + 1)]
542
+ model_prev_i = model_prev_list[-(i + 1)]
543
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
544
+ rk = (lambda_prev_i - lambda_prev_0) / h
545
+ rks.append(rk)
546
+ D1s.append((model_prev_i - model_prev_0) / rk)
547
+
548
+ rks.append(1.)
549
+ rks = torch.tensor(rks, device=x.device)
550
+
551
+ K = len(rks)
552
+ # build C matrix
553
+ C = []
554
+
555
+ col = torch.ones_like(rks)
556
+ for k in range(1, K + 1):
557
+ C.append(col)
558
+ col = col * rks / (k + 1)
559
+ C = torch.stack(C, dim=1)
560
+
561
+ if len(D1s) > 0:
562
+ D1s = torch.stack(D1s, dim=1) # (B, K)
563
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
564
+ A_p = C_inv_p
565
+
566
+ if use_corrector:
567
+ #print('using corrector')
568
+ C_inv = torch.linalg.inv(C)
569
+ A_c = C_inv
570
+
571
+ hh = -h if self.predict_x0 else h
572
+ h_phi_1 = torch.expm1(hh)
573
+ h_phi_ks = []
574
+ factorial_k = 1
575
+ h_phi_k = h_phi_1
576
+ for k in range(1, K + 2):
577
+ h_phi_ks.append(h_phi_k)
578
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
579
+ factorial_k *= (k + 1)
580
+
581
+ model_t = None
582
+ if self.predict_x0:
583
+ x_t_ = (
584
+ sigma_t / sigma_prev_0 * x
585
+ - alpha_t * h_phi_1 * model_prev_0
586
+ )
587
+ # now predictor
588
+ x_t = x_t_
589
+ if len(D1s) > 0:
590
+ # compute the residuals for predictor
591
+ for k in range(K - 1):
592
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
593
+ # now corrector
594
+ if use_corrector:
595
+ model_t = self.model_fn(x_t, t)
596
+ D1_t = (model_t - model_prev_0)
597
+ x_t = x_t_
598
+ k = 0
599
+ for k in range(K - 1):
600
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
601
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
602
+ else:
603
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
604
+ x_t_ = (
605
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
606
+ - (sigma_t * h_phi_1) * model_prev_0
607
+ )
608
+ # now predictor
609
+ x_t = x_t_
610
+ if len(D1s) > 0:
611
+ # compute the residuals for predictor
612
+ for k in range(K - 1):
613
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
614
+ # now corrector
615
+ if use_corrector:
616
+ model_t = self.model_fn(x_t, t)
617
+ D1_t = (model_t - model_prev_0)
618
+ x_t = x_t_
619
+ k = 0
620
+ for k in range(K - 1):
621
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
622
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
623
+ return x_t, model_t
624
+
625
+ def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
626
+ #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
627
+ ns = self.noise_schedule
628
+ assert order <= len(model_prev_list)
629
+ dims = x.dim()
630
+
631
+ # first compute rks
632
+ t_prev_0 = t_prev_list[-1]
633
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
634
+ lambda_t = ns.marginal_lambda(t)
635
+ model_prev_0 = model_prev_list[-1]
636
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
637
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
638
+ alpha_t = torch.exp(log_alpha_t)
639
+
640
+ h = lambda_t - lambda_prev_0
641
+
642
+ rks = []
643
+ D1s = []
644
+ for i in range(1, order):
645
+ t_prev_i = t_prev_list[-(i + 1)]
646
+ model_prev_i = model_prev_list[-(i + 1)]
647
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
648
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
649
+ rks.append(rk)
650
+ D1s.append((model_prev_i - model_prev_0) / rk)
651
+
652
+ rks.append(1.)
653
+ rks = torch.tensor(rks, device=x.device)
654
+
655
+ R = []
656
+ b = []
657
+
658
+ hh = -h[0] if self.predict_x0 else h[0]
659
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
660
+ h_phi_k = h_phi_1 / hh - 1
661
+
662
+ factorial_i = 1
663
+
664
+ if self.variant == 'bh1':
665
+ B_h = hh
666
+ elif self.variant == 'bh2':
667
+ B_h = torch.expm1(hh)
668
+ else:
669
+ raise NotImplementedError()
670
+
671
+ for i in range(1, order + 1):
672
+ R.append(torch.pow(rks, i - 1))
673
+ b.append(h_phi_k * factorial_i / B_h)
674
+ factorial_i *= (i + 1)
675
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
676
+
677
+ R = torch.stack(R)
678
+ b = torch.tensor(b, device=x.device)
679
+
680
+ # now predictor
681
+ use_predictor = len(D1s) > 0 and x_t is None
682
+ if len(D1s) > 0:
683
+ D1s = torch.stack(D1s, dim=1) # (B, K)
684
+ if x_t is None:
685
+ # for order 2, we use a simplified version
686
+ if order == 2:
687
+ rhos_p = torch.tensor([0.5], device=b.device)
688
+ else:
689
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
690
+ else:
691
+ D1s = None
692
+
693
+ if use_corrector:
694
+ #print('using corrector')
695
+ # for order 1, we use a simplified version
696
+ if order == 1:
697
+ rhos_c = torch.tensor([0.5], device=b.device)
698
+ else:
699
+ rhos_c = torch.linalg.solve(R, b)
700
+
701
+ model_t = None
702
+ if self.predict_x0:
703
+ x_t_ = (
704
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
705
+ - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
706
+ )
707
+
708
+ if x_t is None:
709
+ if use_predictor:
710
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
711
+ else:
712
+ pred_res = 0
713
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
714
+
715
+ if use_corrector:
716
+ model_t = self.model_fn(x_t, t)
717
+ if D1s is not None:
718
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
719
+ else:
720
+ corr_res = 0
721
+ D1_t = (model_t - model_prev_0)
722
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
723
+ else:
724
+ x_t_ = (
725
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
726
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
727
+ )
728
+ if x_t is None:
729
+ if use_predictor:
730
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
731
+ else:
732
+ pred_res = 0
733
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
734
+
735
+ if use_corrector:
736
+ model_t = self.model_fn(x_t, t)
737
+ if D1s is not None:
738
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
739
+ else:
740
+ corr_res = 0
741
+ D1_t = (model_t - model_prev_0)
742
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
743
+ return x_t, model_t
744
+
745
+
746
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
747
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
748
+ atol=0.0078, rtol=0.05, corrector=False,
749
+ ):
750
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
751
+ t_T = self.noise_schedule.T if t_start is None else t_start
752
+ device = x.device
753
+ if method == 'multistep':
754
+ assert steps >= order, "UniPC order must not exceed the number of sampling steps"
755
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
756
+ #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
757
+ assert timesteps.shape[0] - 1 == steps
758
+ with torch.no_grad():
759
+ vec_t = timesteps[0].expand((x.shape[0]))
760
+ model_prev_list = [self.model_fn(x, vec_t)]
761
+ t_prev_list = [vec_t]
762
+ with tqdm.tqdm(total=steps) as pbar:
763
+ # Init the first `order` values by lower order multistep DPM-Solver.
764
+ for init_order in range(1, order):
765
+ vec_t = timesteps[init_order].expand(x.shape[0])
766
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
767
+ if model_x is None:
768
+ model_x = self.model_fn(x, vec_t)
769
+ if self.after_update is not None:
770
+ self.after_update(x, model_x)
771
+ model_prev_list.append(model_x)
772
+ t_prev_list.append(vec_t)
773
+ pbar.update()
774
+
775
+ for step in range(order, steps + 1):
776
+ vec_t = timesteps[step].expand(x.shape[0])
777
+ if lower_order_final:
778
+ step_order = min(order, steps + 1 - step)
779
+ else:
780
+ step_order = order
781
+ #print('this step order:', step_order)
782
+ if step == steps:
783
+ #print('do not run corrector at the last step')
784
+ use_corrector = False
785
+ else:
786
+ use_corrector = True
787
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
788
+ if self.after_update is not None:
789
+ self.after_update(x, model_x)
790
+ for i in range(order - 1):
791
+ t_prev_list[i] = t_prev_list[i + 1]
792
+ model_prev_list[i] = model_prev_list[i + 1]
793
+ t_prev_list[-1] = vec_t
794
+ # We do not need to evaluate the final model value.
795
+ if step < steps:
796
+ if model_x is None:
797
+ model_x = self.model_fn(x, vec_t)
798
+ model_prev_list[-1] = model_x
799
+ pbar.update()
800
+ else:
801
+ raise NotImplementedError()
802
+ if denoise_to_zero:
803
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
804
+ return x
805
+
806
+
807
+ #############################################################
808
+ # other utility functions
809
+ #############################################################
810
+
811
+ def interpolate_fn(x, xp, yp):
812
+ """
813
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
814
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
815
+ The function f(x) is well-defined on the entire x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
816
+
817
+ Args:
818
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
819
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
820
+ yp: PyTorch tensor with shape [C, K].
821
+ Returns:
822
+ The function values f(x), with shape [N, C].
823
+ """
824
+ N, K = x.shape[0], xp.shape[1]
825
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
826
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
827
+ x_idx = torch.argmin(x_indices, dim=2)
828
+ cand_start_idx = x_idx - 1
829
+ start_idx = torch.where(
830
+ torch.eq(x_idx, 0),
831
+ torch.tensor(1, device=x.device),
832
+ torch.where(
833
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
834
+ ),
835
+ )
836
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
837
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
838
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
839
+ start_idx2 = torch.where(
840
+ torch.eq(x_idx, 0),
841
+ torch.tensor(0, device=x.device),
842
+ torch.where(
843
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
844
+ ),
845
+ )
846
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
847
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
848
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
849
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
850
+ return cand
851
+
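A usage sketch (assuming `interpolate_fn` is imported from this module): keypoints defining y = 2x on [0, 1], with a query beyond the last keypoint extrapolated from the outermost segment.

```python
import torch

xp = torch.tensor([[0.0, 0.5, 1.0]])        # [C=1, K=3]
yp = torch.tensor([[0.0, 1.0, 2.0]])        # [C=1, K=3]
x = torch.tensor([[0.25], [0.75], [1.5]])   # [N=3, C=1]
print(interpolate_fn(x, xp, yp))            # approximately [[0.5], [1.5], [3.0]]
```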
852
+
853
+ def expand_dims(v, dims):
854
+ """
855
+ Expand the tensor `v` to the dim `dims`.
856
+
857
+ Args:
858
+ `v`: a PyTorch tensor with shape [N].
859
+ `dims`: an `int`.
860
+ Returns:
861
+ a PyTorch tensor with shape [N, 1, 1, ..., 1], where the total number of dimensions is `dims`.
862
+ """
863
+ return v[(...,) + (None,)*(dims - 1)]
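And a one-liner showing what `expand_dims` does to a per-sample scalar so it broadcasts against image-shaped tensors (assuming the function is imported from this module):

```python
import torch

v = torch.tensor([1.0, 2.0])                     # one value per batch element, shape [N]
assert expand_dims(v, 4).shape == (2, 1, 1, 1)   # ready to broadcast against [N, C, H, W]
```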
modules/ngrok.py ADDED
@@ -0,0 +1,30 @@
1
+ import ngrok
2
+
3
+ # Connect to ngrok for ingress
4
+ def connect(token, port, options):
5
+ account = None
6
+ if token is None:
7
+ token = 'None'
8
+ else:
9
+ if ':' in token:
10
+ # token = authtoken:username:password
11
+ token, username, password = token.split(':', 2)
12
+ account = f"{username}:{password}"
13
+
14
+ # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
15
+ if not options.get('authtoken_from_env'):
16
+ options['authtoken'] = token
17
+ if account:
18
+ options['basic_auth'] = account
19
+ if not options.get('session_metadata'):
20
+ options['session_metadata'] = 'stable-diffusion-webui'
21
+
22
+
23
+ try:
24
+ public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
25
+ except Exception as e:
26
+ print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
27
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
28
+ else:
29
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
30
+ 'You can use this link after the launch is complete.')
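For context, a hypothetical invocation of the `connect()` helper above (token, port and options are placeholders; the real call happens in the webui's startup code):

```python
# token may be a bare authtoken, or "authtoken:username:password" to enable basic auth
connect("my-ngrok-authtoken", 7860, {})
connect("my-ngrok-authtoken:user:password", 7860, {"session_metadata": "my-webui"})
```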
modules/paths.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import sys
3
+ from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
4
+
5
+ import modules.safe # noqa: F401
6
+
7
+
8
+ def mute_sdxl_imports():
9
+ """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
10
+
11
+ class Dummy:
12
+ pass
13
+
14
+ module = Dummy()
15
+ module.LPIPS = None
16
+ sys.modules['taming.modules.losses.lpips'] = module
17
+
18
+ module = Dummy()
19
+ module.StableDataModuleFromConfig = None
20
+ sys.modules['sgm.data'] = module
21
+
22
+
23
+ # data_path = cmd_opts_pre.data
24
+ sys.path.insert(0, script_path)
25
+
26
+ # search for directory of stable diffusion in following places
27
+ sd_path = None
28
+ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
29
+ for possible_sd_path in possible_sd_paths:
30
+ if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
31
+ sd_path = os.path.abspath(possible_sd_path)
32
+ break
33
+
34
+ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
35
+
36
+ mute_sdxl_imports()
37
+
38
+ path_dirs = [
39
+ (sd_path, 'ldm', 'Stable Diffusion', []),
40
+ (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
41
+ (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
42
+ (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
43
+ (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
44
+ ]
45
+
46
+ paths = {}
47
+
48
+ for d, must_exist, what, options in path_dirs:
49
+ must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
50
+ if not os.path.exists(must_exist_path):
51
+ print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
52
+ else:
53
+ d = os.path.abspath(d)
54
+ if "atstart" in options:
55
+ sys.path.insert(0, d)
56
+ elif "sgm" in options:
57
+ # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
58
+ # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
59
+
60
+ sys.path.insert(0, d)
61
+ import sgm # noqa: F401
62
+ sys.path.pop(0)
63
+ else:
64
+ sys.path.append(d)
65
+ paths[what] = d
modules/paths_internal.py ADDED
@@ -0,0 +1,31 @@
1
+ """this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
2
+
3
+ import argparse
4
+ import os
5
+ import sys
6
+ import shlex
7
+
8
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
9
+ sys.argv += shlex.split(commandline_args)
10
+
11
+ modules_path = os.path.dirname(os.path.realpath(__file__))
12
+ script_path = os.path.dirname(modules_path)
13
+
14
+ sd_configs_path = os.path.join(script_path, "configs")
15
+ sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
16
+ sd_model_file = os.path.join(script_path, 'model.ckpt')
17
+ default_sd_model_file = sd_model_file
18
+
19
+ # Parse the --data-dir flag first so we can use it as a base for our other argument default values
20
+ parser_pre = argparse.ArgumentParser(add_help=False)
21
+ parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
22
+ cmd_opts_pre = parser_pre.parse_known_args()[0]
23
+
24
+ data_path = cmd_opts_pre.data_dir
25
+
26
+ models_path = os.path.join(data_path, "models")
27
+ extensions_dir = os.path.join(data_path, "extensions")
28
+ extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
29
+ config_states_dir = os.path.join(script_path, "config_states")
30
+
31
+ roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
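The `--data-dir` pre-parse above is a common argparse pattern: read one flag before the full parser exists and ignore everything else on the command line. The same pattern in isolation (illustrative values only):

```python
import argparse

pre = argparse.ArgumentParser(add_help=False)
pre.add_argument("--data-dir", type=str, default=".")
known, _rest = pre.parse_known_args(["--data-dir", "/tmp/webui-data", "--some-other-flag"])
print(known.data_dir)   # /tmp/webui-data
```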
modules/postprocessing.py ADDED
@@ -0,0 +1,109 @@
1
+ import os
2
+
3
+ from PIL import Image
4
+
5
+ from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
6
+ from modules.shared import opts
7
+
8
+
9
+ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
10
+ devices.torch_gc()
11
+
12
+ shared.state.begin(job="extras")
13
+
14
+ image_data = []
15
+ image_names = []
16
+ outputs = []
17
+
18
+ if extras_mode == 1:
19
+ for img in image_folder:
20
+ if isinstance(img, Image.Image):
21
+ image = img
22
+ fn = ''
23
+ else:
24
+ image = Image.open(os.path.abspath(img.name))
25
+ fn = os.path.splitext(img.orig_name)[0]
26
+ image_data.append(image)
27
+ image_names.append(fn)
28
+ elif extras_mode == 2:
29
+ assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
30
+ assert input_dir, 'input directory not selected'
31
+
32
+ image_list = shared.listfiles(input_dir)
33
+ for filename in image_list:
34
+ try:
35
+ image = Image.open(filename)
36
+ except Exception:
37
+ continue
38
+ image_data.append(image)
39
+ image_names.append(filename)
40
+ else:
41
+ assert image, 'image not selected'
42
+
43
+ image_data.append(image)
44
+ image_names.append(None)
45
+
46
+ if extras_mode == 2 and output_dir != '':
47
+ outpath = output_dir
48
+ else:
49
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
50
+
51
+ infotext = ''
52
+
53
+ for image, name in zip(image_data, image_names):
54
+ shared.state.textinfo = name
55
+
56
+ parameters, existing_pnginfo = images.read_info_from_image(image)
57
+ if parameters:
58
+ existing_pnginfo["parameters"] = parameters
59
+
60
+ pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
61
+
62
+ scripts.scripts_postproc.run(pp, args)
63
+
64
+ if opts.use_original_name_batch and name is not None:
65
+ basename = os.path.splitext(os.path.basename(name))[0]
66
+ else:
67
+ basename = ''
68
+
69
+ infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
70
+
71
+ if opts.enable_pnginfo:
72
+ pp.image.info = existing_pnginfo
73
+ pp.image.info["postprocessing"] = infotext
74
+
75
+ if save_output:
76
+ images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
77
+
78
+ if extras_mode != 2 or show_extras_results:
79
+ outputs.append(pp.image)
80
+
81
+ devices.torch_gc()
82
+
83
+ return outputs, ui_common.plaintext_to_html(infotext), ''
84
+
85
+
86
+ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
87
+ """old handler for API"""
88
+
89
+ args = scripts.scripts_postproc.create_args_for_run({
90
+ "Upscale": {
91
+ "upscale_mode": resize_mode,
92
+ "upscale_by": upscaling_resize,
93
+ "upscale_to_width": upscaling_resize_w,
94
+ "upscale_to_height": upscaling_resize_h,
95
+ "upscale_crop": upscaling_crop,
96
+ "upscaler_1_name": extras_upscaler_1,
97
+ "upscaler_2_name": extras_upscaler_2,
98
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
99
+ },
100
+ "GFPGAN": {
101
+ "gfpgan_visibility": gfpgan_visibility,
102
+ },
103
+ "CodeFormer": {
104
+ "codeformer_visibility": codeformer_visibility,
105
+ "codeformer_weight": codeformer_weight,
106
+ },
107
+ })
108
+
109
+ return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
modules/processing.py ADDED
@@ -0,0 +1,1405 @@
1
+ import json
2
+ import logging
3
+ import math
4
+ import os
5
+ import sys
6
+ import hashlib
7
+
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image, ImageOps
11
+ import random
12
+ import cv2
13
+ from skimage import exposure
14
+ from typing import Any, Dict, List
15
+
16
+ import modules.sd_hijack
17
+ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
18
+ from modules.sd_hijack import model_hijack
19
+ from modules.shared import opts, cmd_opts, state
20
+ import modules.shared as shared
21
+ import modules.paths as paths
22
+ import modules.face_restoration
23
+ import modules.images as images
24
+ import modules.styles
25
+ import modules.sd_models as sd_models
26
+ import modules.sd_vae as sd_vae
27
+ from ldm.data.util import AddMiDaS
28
+ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
29
+
30
+ from einops import repeat, rearrange
31
+ from blendmodes.blend import blendLayers, BlendType
32
+
33
+
34
+ # some of those options should not be changed at all because they would break the model, so I removed them from options.
35
+ opt_C = 4
36
+ opt_f = 8
37
+
38
+
39
+ def setup_color_correction(image):
40
+ logging.info("Calibrating color correction.")
41
+ correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
42
+ return correction_target
43
+
44
+
45
+ def apply_color_correction(correction, original_image):
46
+ logging.info("Applying color correction.")
47
+ image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
48
+ cv2.cvtColor(
49
+ np.asarray(original_image),
50
+ cv2.COLOR_RGB2LAB
51
+ ),
52
+ correction,
53
+ channel_axis=2
54
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
55
+
56
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
57
+
58
+ return image
59
+
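A hedged usage sketch of the two helpers above (file names are placeholders): capture a correction target from the img2img input once, then pull each generated image back toward it.

```python
from PIL import Image

reference = Image.open("input.png").convert("RGB")       # original img2img input
generated = Image.open("generated.png").convert("RGB")   # a processed result

correction = setup_color_correction(reference)
corrected = apply_color_correction(correction, generated)
corrected.save("generated_color_corrected.png")
```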
60
+
61
+ def apply_overlay(image, paste_loc, index, overlays):
62
+ if overlays is None or index >= len(overlays):
63
+ return image
64
+
65
+ overlay = overlays[index]
66
+
67
+ if paste_loc is not None:
68
+ x, y, w, h = paste_loc
69
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
70
+ image = images.resize_image(1, image, w, h)
71
+ base_image.paste(image, (x, y))
72
+ image = base_image
73
+
74
+ image = image.convert('RGBA')
75
+ image.alpha_composite(overlay)
76
+ image = image.convert('RGB')
77
+
78
+ return image
79
+
80
+
81
+ def txt2img_image_conditioning(sd_model, x, width, height):
82
+ if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
83
+
84
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
85
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
86
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
87
+
88
+ # Add the fake full 1s mask to the first dimension.
89
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
90
+ image_conditioning = image_conditioning.to(x.dtype)
91
+
92
+ return image_conditioning
93
+
94
+ elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
95
+
96
+ return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
97
+
98
+ else:
99
+ # Dummy zero conditioning if we're not using inpainting or unclip models.
100
+ # Still takes up a bit of memory, but no encoder call.
101
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
102
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
103
+
104
+
105
+ class StableDiffusionProcessing:
106
+ """
107
+ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
108
+ """
109
+ cached_uc = [None, None]
110
+ cached_c = [None, None]
111
+
112
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
113
+ if sampler_index is not None:
114
+ print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
115
+
116
+ self.outpath_samples: str = outpath_samples
117
+ self.outpath_grids: str = outpath_grids
118
+ self.prompt: str = prompt
119
+ self.prompt_for_display: str = None
120
+ self.negative_prompt: str = (negative_prompt or "")
121
+ self.styles: list = styles or []
122
+ self.seed: int = seed
123
+ self.subseed: int = subseed
124
+ self.subseed_strength: float = subseed_strength
125
+ self.seed_resize_from_h: int = seed_resize_from_h
126
+ self.seed_resize_from_w: int = seed_resize_from_w
127
+ self.sampler_name: str = sampler_name
128
+ self.batch_size: int = batch_size
129
+ self.n_iter: int = n_iter
130
+ self.steps: int = steps
131
+ self.cfg_scale: float = cfg_scale
132
+ self.width: int = width
133
+ self.height: int = height
134
+ self.restore_faces: bool = restore_faces
135
+ self.tiling: bool = tiling
136
+ self.do_not_save_samples: bool = do_not_save_samples
137
+ self.do_not_save_grid: bool = do_not_save_grid
138
+ self.extra_generation_params: dict = extra_generation_params or {}
139
+ self.overlay_images = overlay_images
140
+ self.eta = eta
141
+ self.do_not_reload_embeddings = do_not_reload_embeddings
142
+ self.paste_to = None
143
+ self.color_corrections = None
144
+ self.denoising_strength: float = denoising_strength
145
+ self.sampler_noise_scheduler_override = None
146
+ self.ddim_discretize = ddim_discretize or opts.ddim_discretize
147
+ self.s_min_uncond = s_min_uncond or opts.s_min_uncond
148
+ self.s_churn = s_churn or opts.s_churn
149
+ self.s_tmin = s_tmin or opts.s_tmin
150
+ self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
151
+ self.s_noise = s_noise or opts.s_noise
152
+ self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
153
+ self.override_settings_restore_afterwards = override_settings_restore_afterwards
154
+ self.is_using_inpainting_conditioning = False
155
+ self.disable_extra_networks = False
156
+ self.token_merging_ratio = 0
157
+ self.token_merging_ratio_hr = 0
158
+
159
+ if not seed_enable_extras:
160
+ self.subseed = -1
161
+ self.subseed_strength = 0
162
+ self.seed_resize_from_h = 0
163
+ self.seed_resize_from_w = 0
164
+
165
+ self.scripts = None
166
+ self.script_args = script_args
167
+ self.all_prompts = None
168
+ self.all_negative_prompts = None
169
+ self.all_seeds = None
170
+ self.all_subseeds = None
171
+ self.iteration = 0
172
+ self.is_hr_pass = False
173
+ self.sampler = None
174
+
175
+ self.prompts = None
176
+ self.negative_prompts = None
177
+ self.extra_network_data = None
178
+ self.seeds = None
179
+ self.subseeds = None
180
+
181
+ self.step_multiplier = 1
182
+ self.cached_uc = StableDiffusionProcessing.cached_uc
183
+ self.cached_c = StableDiffusionProcessing.cached_c
184
+ self.uc = None
185
+ self.c = None
186
+
187
+ self.user = None
188
+
189
+ @property
190
+ def sd_model(self):
191
+ return shared.sd_model
192
+
193
+ def txt2img_image_conditioning(self, x, width=None, height=None):
194
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
195
+
196
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
197
+
198
+ def depth2img_image_conditioning(self, source_image):
199
+ # Use the AddMiDaS helper to format our source image to suit the MiDaS model
200
+ transformer = AddMiDaS(model_type="dpt_hybrid")
201
+ transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
202
+ midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
203
+ midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
204
+
205
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
206
+ conditioning = torch.nn.functional.interpolate(
207
+ self.sd_model.depth_model(midas_in),
208
+ size=conditioning_image.shape[2:],
209
+ mode="bicubic",
210
+ align_corners=False,
211
+ )
212
+
213
+ (depth_min, depth_max) = torch.aminmax(conditioning)
214
+ conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
215
+ return conditioning
216
+
217
+ def edit_image_conditioning(self, source_image):
218
+ conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
219
+
220
+ return conditioning_image
221
+
222
+ def unclip_image_conditioning(self, source_image):
223
+ c_adm = self.sd_model.embedder(source_image)
224
+ if self.sd_model.noise_augmentor is not None:
225
+ noise_level = 0 # TODO: Allow other noise levels?
226
+ c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
227
+ c_adm = torch.cat((c_adm, noise_level_emb), 1)
228
+ return c_adm
229
+
230
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
231
+ self.is_using_inpainting_conditioning = True
232
+
233
+ # Handle the different mask inputs
234
+ if image_mask is not None:
235
+ if torch.is_tensor(image_mask):
236
+ conditioning_mask = image_mask
237
+ else:
238
+ conditioning_mask = np.array(image_mask.convert("L"))
239
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
240
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
241
+
242
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
243
+ conditioning_mask = torch.round(conditioning_mask)
244
+ else:
245
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
246
+
247
+ # Create another latent image, this time with a masked version of the original input.
248
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
249
+ conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
250
+ conditioning_image = torch.lerp(
251
+ source_image,
252
+ source_image * (1.0 - conditioning_mask),
253
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
254
+ )
255
+
256
+ # Encode the new masked image using first stage of network.
257
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
258
+
259
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
260
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
261
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
262
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
263
+ image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
264
+
265
+ return image_conditioning
266
+
267
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
268
+ source_image = devices.cond_cast_float(source_image)
269
+
270
+ # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
271
+ # identify itself with a field common to all models. The conditioning_key is also hybrid.
272
+ if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
273
+ return self.depth2img_image_conditioning(source_image)
274
+
275
+ if self.sd_model.cond_stage_key == "edit":
276
+ return self.edit_image_conditioning(source_image)
277
+
278
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
279
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
280
+
281
+ if self.sampler.conditioning_key == "crossattn-adm":
282
+ return self.unclip_image_conditioning(source_image)
283
+
284
+ # Dummy zero conditioning if we're not using inpainting or depth model.
285
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
286
+
287
+ def init(self, all_prompts, all_seeds, all_subseeds):
288
+ pass
289
+
290
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
291
+ raise NotImplementedError()
292
+
293
+ def close(self):
294
+ self.sampler = None
295
+ self.c = None
296
+ self.uc = None
297
+ if not opts.experimental_persistent_cond_cache:
298
+ StableDiffusionProcessing.cached_c = [None, None]
299
+ StableDiffusionProcessing.cached_uc = [None, None]
300
+
301
+ def get_token_merging_ratio(self, for_hr=False):
302
+ if for_hr:
303
+ return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
304
+
305
+ return self.token_merging_ratio or opts.token_merging_ratio
306
+
307
+ def setup_prompts(self):
308
+ if type(self.prompt) == list:
309
+ self.all_prompts = self.prompt
310
+ else:
311
+ self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
312
+
313
+ if type(self.negative_prompt) == list:
314
+ self.all_negative_prompts = self.negative_prompt
315
+ else:
316
+ self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
317
+
318
+ self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
319
+ self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
320
+
321
+ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
322
+ """
323
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
324
+ using a cache to store the result if the same arguments have been used before.
325
+
326
+ cache is an array containing two elements. The first element is a tuple
327
+ representing the previously used arguments, or None if no arguments
328
+ have been used before. The second element is where the previously
329
+ computed result is stored.
330
+
331
+ caches is a list with items described above.
332
+ """
333
+
334
+ cached_params = (
335
+ required_prompts,
336
+ steps,
337
+ opts.CLIP_stop_at_last_layers,
338
+ shared.sd_model.sd_checkpoint_info,
339
+ extra_network_data,
340
+ opts.sdxl_crop_left,
341
+ opts.sdxl_crop_top,
342
+ self.width,
343
+ self.height,
344
+ )
345
+
346
+ for cache in caches:
347
+ if cache[0] is not None and cached_params == cache[0]:
348
+ return cache[1]
349
+
350
+ cache = caches[0]
351
+
352
+ with devices.autocast():
353
+ cache[1] = function(shared.sd_model, required_prompts, steps)
354
+
355
+ cache[0] = cached_params
356
+ return cache[1]
357
+
358
+ def setup_conds(self):
359
+ prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
360
+ negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
361
+
362
+ sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
363
+ self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
364
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
365
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
366
+
367
+ def parse_extra_network_prompts(self):
368
+ self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
369
+
370
+
371
+ class Processed:
372
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
373
+ self.images = images_list
374
+ self.prompt = p.prompt
375
+ self.negative_prompt = p.negative_prompt
376
+ self.seed = seed
377
+ self.subseed = subseed
378
+ self.subseed_strength = p.subseed_strength
379
+ self.info = info
380
+ self.comments = comments
381
+ self.width = p.width
382
+ self.height = p.height
383
+ self.sampler_name = p.sampler_name
384
+ self.cfg_scale = p.cfg_scale
385
+ self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
386
+ self.steps = p.steps
387
+ self.batch_size = p.batch_size
388
+ self.restore_faces = p.restore_faces
389
+ self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
390
+ self.sd_model_hash = shared.sd_model.sd_model_hash
391
+ self.seed_resize_from_w = p.seed_resize_from_w
392
+ self.seed_resize_from_h = p.seed_resize_from_h
393
+ self.denoising_strength = getattr(p, 'denoising_strength', None)
394
+ self.extra_generation_params = p.extra_generation_params
395
+ self.index_of_first_image = index_of_first_image
396
+ self.styles = p.styles
397
+ self.job_timestamp = state.job_timestamp
398
+ self.clip_skip = opts.CLIP_stop_at_last_layers
399
+ self.token_merging_ratio = p.token_merging_ratio
400
+ self.token_merging_ratio_hr = p.token_merging_ratio_hr
401
+
402
+ self.eta = p.eta
403
+ self.ddim_discretize = p.ddim_discretize
404
+ self.s_churn = p.s_churn
405
+ self.s_tmin = p.s_tmin
406
+ self.s_tmax = p.s_tmax
407
+ self.s_noise = p.s_noise
408
+ self.s_min_uncond = p.s_min_uncond
409
+ self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
410
+ self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
411
+ self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
412
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
413
+ self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
414
+ self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
415
+
416
+ self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
417
+ self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
418
+ self.all_seeds = all_seeds or p.all_seeds or [self.seed]
419
+ self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
420
+ self.infotexts = infotexts or [info]
421
+
422
+ def js(self):
423
+ obj = {
424
+ "prompt": self.all_prompts[0],
425
+ "all_prompts": self.all_prompts,
426
+ "negative_prompt": self.all_negative_prompts[0],
427
+ "all_negative_prompts": self.all_negative_prompts,
428
+ "seed": self.seed,
429
+ "all_seeds": self.all_seeds,
430
+ "subseed": self.subseed,
431
+ "all_subseeds": self.all_subseeds,
432
+ "subseed_strength": self.subseed_strength,
433
+ "width": self.width,
434
+ "height": self.height,
435
+ "sampler_name": self.sampler_name,
436
+ "cfg_scale": self.cfg_scale,
437
+ "steps": self.steps,
438
+ "batch_size": self.batch_size,
439
+ "restore_faces": self.restore_faces,
440
+ "face_restoration_model": self.face_restoration_model,
441
+ "sd_model_hash": self.sd_model_hash,
442
+ "seed_resize_from_w": self.seed_resize_from_w,
443
+ "seed_resize_from_h": self.seed_resize_from_h,
444
+ "denoising_strength": self.denoising_strength,
445
+ "extra_generation_params": self.extra_generation_params,
446
+ "index_of_first_image": self.index_of_first_image,
447
+ "infotexts": self.infotexts,
448
+ "styles": self.styles,
449
+ "job_timestamp": self.job_timestamp,
450
+ "clip_skip": self.clip_skip,
451
+ "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
452
+ }
453
+
454
+ return json.dumps(obj)
455
+
456
+ def infotext(self, p: StableDiffusionProcessing, index):
457
+ return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
458
+
459
+ def get_token_merging_ratio(self, for_hr=False):
460
+ return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
461
+
462
+
463
+ # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
464
+ def slerp(val, low, high):
465
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
466
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
467
+ dot = (low_norm*high_norm).sum(1)
468
+
469
+ if dot.mean() > 0.9995:
470
+ return low * val + high * (1 - val)
471
+
472
+ omega = torch.acos(dot)
473
+ so = torch.sin(omega)
474
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
475
+ return res
476
+
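A small usage sketch of `slerp` as it is used for subseeds below (shapes are illustrative; assumes the function above is in scope):

```python
import torch

noise = torch.randn(1, 4, 64, 64)      # noise from the main seed
subnoise = torch.randn(1, 4, 64, 64)   # noise from the subseed
mixed = slerp(0.3, noise, subnoise)    # subseed_strength: 0.0 -> noise, 1.0 -> subnoise
assert mixed.shape == noise.shape
```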
477
+
478
+ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
479
+ eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
480
+ xs = []
481
+
482
+ # if we have multiple seeds, this means we are working with batch size>1; this then
483
+ # enables the generation of additional tensors with noise that the sampler will use during its processing.
484
+ # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
485
+ # produce the same images as with two batches [100], [101].
486
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
487
+ sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
488
+ else:
489
+ sampler_noises = None
490
+
491
+ for i, seed in enumerate(seeds):
492
+ noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
493
+
494
+ subnoise = None
495
+ if subseeds is not None:
496
+ subseed = 0 if i >= len(subseeds) else subseeds[i]
497
+
498
+ subnoise = devices.randn(subseed, noise_shape)
499
+
500
+ # randn results depend on the device; GPU and CPU produce different results for the same seed.
501
+ # Ideally this would always run on the CPU so that every user gets the same result, but the
502
+ # original script generated the noise on the device, and changing that now would
503
+ # break everyone's existing seeds.
504
+ noise = devices.randn(seed, noise_shape)
505
+
506
+ if subnoise is not None:
507
+ noise = slerp(subseed_strength, noise, subnoise)
508
+
509
+ if noise_shape != shape:
510
+ x = devices.randn(seed, shape)
511
+ dx = (shape[2] - noise_shape[2]) // 2
512
+ dy = (shape[1] - noise_shape[1]) // 2
513
+ w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
514
+ h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
515
+ tx = 0 if dx < 0 else dx
516
+ ty = 0 if dy < 0 else dy
517
+ dx = max(-dx, 0)
518
+ dy = max(-dy, 0)
519
+
520
+ x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
521
+ noise = x
522
+
523
+ if sampler_noises is not None:
524
+ cnt = p.sampler.number_of_needed_noises(p)
525
+
526
+ if eta_noise_seed_delta > 0:
527
+ torch.manual_seed(seed + eta_noise_seed_delta)
528
+
529
+ for j in range(cnt):
530
+ sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
531
+
532
+ xs.append(noise)
533
+
534
+ if sampler_noises is not None:
535
+ p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
536
+
537
+ x = torch.stack(xs).to(shared.device)
538
+ return x
539
+
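+ # Minimal usage sketch (assuming the usual SD latent layout with 4 channels and an 8x downscale):
+ # for a 512x512 image and two seeds, the call below returns a (2, 4, 64, 64) tensor of starting
+ # noise in which each batch element is reproducible from its own seed:
+ #
+ # x = create_random_tensors([4, 512 // 8, 512 // 8], seeds=[100, 101])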
540
+
541
+ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
542
+ samples = []
543
+
544
+ for i in range(batch.shape[0]):
545
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
546
+
547
+ if check_for_nans:
548
+ try:
549
+ devices.test_for_nans(sample, "vae")
550
+ except devices.NansException as e:
551
+ if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
552
+ raise e
553
+
554
+ errors.print_error_explanation(
555
+ "A tensor with all NaNs was produced in VAE.\n"
556
+ "Web UI will now convert VAE into 32-bit float and retry.\n"
557
+ "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
558
+ "To always start with 32-bit VAE, use --no-half-vae commandline flag."
559
+ )
560
+
561
+ devices.dtype_vae = torch.float32
562
+ model.first_stage_model.to(devices.dtype_vae)
563
+ batch = batch.to(devices.dtype_vae)
564
+
565
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
566
+
567
+ if target_device is not None:
568
+ sample = sample.to(target_device)
569
+
570
+ samples.append(sample)
571
+
572
+ return samples
573
+
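+ # Rough usage sketch: given sampler output of shape (batch, 4, H/8, W/8), this returns a list of
+ # decoded image tensors in [-1, 1], one per batch element, mirroring the call made further below:
+ #
+ # decoded = decode_latent_batch(shared.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)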
574
+
575
+ def decode_first_stage(model, x):
576
+ x = model.decode_first_stage(x.to(devices.dtype_vae))
577
+
578
+ return x
579
+
580
+
581
+ def get_fixed_seed(seed):
582
+ if seed is None or seed == '' or seed == -1:
583
+ return int(random.randrange(4294967294))
584
+
585
+ return seed
586
+
587
+
588
+ def fix_seed(p):
589
+ p.seed = get_fixed_seed(p.seed)
590
+ p.subseed = get_fixed_seed(p.subseed)
591
+
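+ # For example (the actual value is random): get_fixed_seed(-1) and get_fixed_seed('') both return a
+ # fresh random integer in [0, 4294967293], while get_fixed_seed(12345) returns 12345 unchanged.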
592
+
593
+ def program_version():
594
+ import launch
595
+
596
+ res = launch.git_tag()
597
+ if res == "<none>":
598
+ res = None
599
+
600
+ return res
601
+
602
+
603
+ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
604
+ index = position_in_batch + iteration * p.batch_size
605
+
606
+ clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
607
+ enable_hr = getattr(p, 'enable_hr', False)
608
+ token_merging_ratio = p.get_token_merging_ratio()
609
+ token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
610
+
611
+ uses_ensd = opts.eta_noise_seed_delta != 0
612
+ if uses_ensd:
613
+ uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
614
+
615
+ generation_params = {
616
+ "Steps": p.steps,
617
+ "Sampler": p.sampler_name,
618
+ "CFG scale": p.cfg_scale,
619
+ "Image CFG scale": getattr(p, 'image_cfg_scale', None),
620
+ "Seed": all_seeds[index],
621
+ "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
622
+ "Size": f"{p.width}x{p.height}",
623
+ "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
624
+ "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
625
+ "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
626
+ "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
627
+ "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
628
+ "Denoising strength": getattr(p, 'denoising_strength', None),
629
+ "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
630
+ "Clip skip": None if clip_skip <= 1 else clip_skip,
631
+ "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
632
+ "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
633
+ "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
634
+ "Init image hash": getattr(p, 'init_img_hash', None),
635
+ "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
636
+ "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
637
+ **p.extra_generation_params,
638
+ "Version": program_version() if opts.add_version_to_infotext else None,
639
+ "User": p.user if opts.add_user_name_to_info else None,
640
+ }
641
+
642
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
643
+
644
+ prompt_text = p.prompt if use_main_prompt else all_prompts[index]
645
+ negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
646
+
647
+ return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
648
+
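+ # The returned infotext is the familiar PNG-info block; roughly (all values below invented):
+ #
+ # a photo of a cat
+ # Negative prompt: blurry
+ # Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 100, Size: 512x512, Model hash: abc123de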
649
+
650
+ def process_images(p: StableDiffusionProcessing) -> Processed:
651
+ if p.scripts is not None:
652
+ p.scripts.before_process(p)
653
+
654
+ stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
655
+
656
+ try:
657
+ # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
658
+ if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
659
+ p.override_settings.pop('sd_model_checkpoint', None)
660
+ sd_models.reload_model_weights()
661
+
662
+ for k, v in p.override_settings.items():
663
+ setattr(opts, k, v)
664
+
665
+ if k == 'sd_model_checkpoint':
666
+ sd_models.reload_model_weights()
667
+
668
+ if k == 'sd_vae':
669
+ sd_vae.reload_vae_weights()
670
+
671
+ sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
672
+
673
+ res = process_images_inner(p)
674
+
675
+ finally:
676
+ sd_models.apply_token_merging(p.sd_model, 0)
677
+
678
+ # restore opts to original state
679
+ if p.override_settings_restore_afterwards:
680
+ for k, v in stored_opts.items():
681
+ setattr(opts, k, v)
682
+
683
+ if k == 'sd_vae':
684
+ sd_vae.reload_vae_weights()
685
+
686
+ return res
687
+
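+ # Sketch of how override_settings is consumed above: a caller can request per-generation option
+ # overrides that are applied before sampling and restored afterwards (example values assumed):
+ #
+ # p.override_settings = {"CLIP_stop_at_last_layers": 2, "sd_vae": "None"}
+ # processed = process_images(p)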
688
+
689
+ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
690
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
691
+
692
+ if type(p.prompt) == list:
693
+ assert len(p.prompt) > 0
694
+ else:
695
+ assert p.prompt is not None
696
+
697
+ devices.torch_gc()
698
+
699
+ seed = get_fixed_seed(p.seed)
700
+ subseed = get_fixed_seed(p.subseed)
701
+
702
+ modules.sd_hijack.model_hijack.apply_circular(p.tiling)
703
+ modules.sd_hijack.model_hijack.clear_comments()
704
+
705
+ comments = {}
706
+
707
+ p.setup_prompts()
708
+
709
+ if type(seed) == list:
710
+ p.all_seeds = seed
711
+ else:
712
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
713
+
714
+ if type(subseed) == list:
715
+ p.all_subseeds = subseed
716
+ else:
717
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
718
+
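+ # Worked example of the seed bookkeeping above: with seed=100, subseed=500, three prompts and
+ # subseed_strength=0, all_seeds becomes [100, 101, 102] and all_subseeds [500, 501, 502]; with
+ # subseed_strength>0 every image keeps seed 100 and only the subseed advances.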
719
+ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
720
+ all_prompts = p.all_prompts[:]
721
+ all_negative_prompts = p.all_negative_prompts[:]
722
+ all_seeds = p.all_seeds[:]
723
+ all_subseeds = p.all_subseeds[:]
724
+
725
+ # apply changes to generation data
726
+ all_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.prompts
727
+ all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts
728
+ all_seeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.seeds
729
+ all_subseeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.subseeds
730
+
731
+ # update p.all_negative_prompts in case extensions changed the size of the batch
732
+ # create_infotext below uses it
733
+ old_negative_prompts = p.all_negative_prompts
734
+ p.all_negative_prompts = all_negative_prompts
735
+
736
+ try:
737
+ return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
738
+ finally:
739
+ # restore p.all_negative_prompts in case extensions changed the size of the batch
740
+ p.all_negative_prompts = old_negative_prompts
741
+
742
+ if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
743
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
744
+
745
+ if p.scripts is not None:
746
+ p.scripts.process(p)
747
+
748
+ infotexts = []
749
+ output_images = []
750
+
751
+ with torch.no_grad(), p.sd_model.ema_scope():
752
+ with devices.autocast():
753
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
754
+
755
+ # on macOS, loading the approximation model during sampling changes the generated picture, so it is loaded here ahead of time
756
+ if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
757
+ sd_vae_approx.model()
758
+
759
+ sd_unet.apply_unet()
760
+
761
+ if state.job_count == -1:
762
+ state.job_count = p.n_iter
763
+
764
+ for n in range(p.n_iter):
765
+ p.iteration = n
766
+
767
+ if state.skipped:
768
+ state.skipped = False
769
+
770
+ if state.interrupted:
771
+ break
772
+
773
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
774
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
775
+ p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
776
+ p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
777
+
778
+ if p.scripts is not None:
779
+ p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
780
+
781
+ if len(p.prompts) == 0:
782
+ break
783
+
784
+ p.parse_extra_network_prompts()
785
+
786
+ if not p.disable_extra_networks:
787
+ with devices.autocast():
788
+ extra_networks.activate(p, p.extra_network_data)
789
+
790
+ if p.scripts is not None:
791
+ p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
792
+
793
+ # params.txt should be saved after scripts.process_batch, since the
794
+ # infotext could be modified by that callback
795
+ # Example: a wildcard processed by process_batch sets an extra model
796
+ # strength, which is saved as "Model Strength: 1.0" in the infotext
797
+ if n == 0:
798
+ with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
799
+ processed = Processed(p, [], p.seed, "")
800
+ file.write(processed.infotext(p, 0))
801
+
802
+ p.setup_conds()
803
+
804
+ for comment in model_hijack.comments:
805
+ comments[comment] = 1
806
+
807
+ p.extra_generation_params.update(model_hijack.extra_generation_params)
808
+
809
+ if p.n_iter > 1:
810
+ shared.state.job = f"Batch {n+1} out of {p.n_iter}"
811
+
812
+ with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
813
+ samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
814
+
815
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
816
+ x_samples_ddim = torch.stack(x_samples_ddim).float()
817
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
818
+
819
+ del samples_ddim
820
+
821
+ if lowvram.is_enabled(shared.sd_model):
822
+ lowvram.send_everything_to_cpu()
823
+
824
+ devices.torch_gc()
825
+
826
+ if p.scripts is not None:
827
+ p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
828
+
829
+ postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
830
+ p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
831
+ x_samples_ddim = postprocess_batch_list_args.images
832
+
833
+ for i, x_sample in enumerate(x_samples_ddim):
834
+ p.batch_index = i
835
+
836
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
837
+ x_sample = x_sample.astype(np.uint8)
838
+
839
+ if p.restore_faces:
840
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
841
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
842
+
843
+ devices.torch_gc()
844
+
845
+ x_sample = modules.face_restoration.restore_faces(x_sample)
846
+ devices.torch_gc()
847
+
848
+ image = Image.fromarray(x_sample)
849
+
850
+ if p.scripts is not None:
851
+ pp = scripts.PostprocessImageArgs(image)
852
+ p.scripts.postprocess_image(p, pp)
853
+ image = pp.image
854
+
855
+ if p.color_corrections is not None and i < len(p.color_corrections):
856
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
857
+ image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
858
+ images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
859
+ image = apply_color_correction(p.color_corrections[i], image)
860
+
861
+ image = apply_overlay(image, p.paste_to, i, p.overlay_images)
862
+
863
+ if opts.samples_save and not p.do_not_save_samples:
864
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
865
+
866
+ text = infotext(n, i)
867
+ infotexts.append(text)
868
+ if opts.enable_pnginfo:
869
+ image.info["parameters"] = text
870
+ output_images.append(image)
871
+
872
+ if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
873
+ image_mask = p.mask_for_overlay.convert('RGB')
874
+ image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
875
+
876
+ if opts.save_mask:
877
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
878
+
879
+ if opts.save_mask_composite:
880
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
881
+
882
+ if opts.return_mask:
883
+ output_images.append(image_mask)
884
+
885
+ if opts.return_mask_composite:
886
+ output_images.append(image_mask_composite)
887
+
888
+ del x_samples_ddim
889
+
890
+ devices.torch_gc()
891
+
892
+ state.nextjob()
893
+
894
+ p.color_corrections = None
895
+
896
+ index_of_first_image = 0
897
+ unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
898
+ if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
899
+ grid = images.image_grid(output_images, p.batch_size)
900
+
901
+ if opts.return_grid:
902
+ text = infotext(use_main_prompt=True)
903
+ infotexts.insert(0, text)
904
+ if opts.enable_pnginfo:
905
+ grid.info["parameters"] = text
906
+ output_images.insert(0, grid)
907
+ index_of_first_image = 1
908
+
909
+ if opts.grid_save:
910
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
911
+
912
+ if not p.disable_extra_networks and p.extra_network_data:
913
+ extra_networks.deactivate(p, p.extra_network_data)
914
+
915
+ devices.torch_gc()
916
+
917
+ res = Processed(
918
+ p,
919
+ images_list=output_images,
920
+ seed=p.all_seeds[0],
921
+ info=infotext(),
922
+ comments="".join(f"{comment}\n" for comment in comments),
923
+ subseed=p.all_subseeds[0],
924
+ index_of_first_image=index_of_first_image,
925
+ infotexts=infotexts,
926
+ )
927
+
928
+ if p.scripts is not None:
929
+ p.scripts.postprocess(p, res)
930
+
931
+ return res
932
+
933
+
934
+ def old_hires_fix_first_pass_dimensions(width, height):
935
+ """old algorithm for auto-calculating first pass size"""
936
+
937
+ desired_pixel_count = 512 * 512
938
+ actual_pixel_count = width * height
939
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
940
+ width = math.ceil(scale * width / 64) * 64
941
+ height = math.ceil(scale * height / 64) * 64
942
+
943
+ return width, height
944
+
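+ # Worked example: for a 1024x768 request, scale = sqrt(262144 / 786432), about 0.577, so the first
+ # pass becomes ceil(0.577*1024/64)*64 x ceil(0.577*768/64)*64 = 640x448, roughly 512*512 pixels
+ # rounded up to multiples of 64.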
945
+
946
+ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
947
+ sampler = None
948
+ cached_hr_uc = [None, None]
949
+ cached_hr_c = [None, None]
950
+
951
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
952
+ super().__init__(**kwargs)
953
+ self.enable_hr = enable_hr
954
+ self.denoising_strength = denoising_strength
955
+ self.hr_scale = hr_scale
956
+ self.hr_upscaler = hr_upscaler
957
+ self.hr_second_pass_steps = hr_second_pass_steps
958
+ self.hr_resize_x = hr_resize_x
959
+ self.hr_resize_y = hr_resize_y
960
+ self.hr_upscale_to_x = hr_resize_x
961
+ self.hr_upscale_to_y = hr_resize_y
962
+ self.hr_sampler_name = hr_sampler_name
963
+ self.hr_prompt = hr_prompt
964
+ self.hr_negative_prompt = hr_negative_prompt
965
+ self.all_hr_prompts = None
966
+ self.all_hr_negative_prompts = None
967
+
968
+ if firstphase_width != 0 or firstphase_height != 0:
969
+ self.hr_upscale_to_x = self.width
970
+ self.hr_upscale_to_y = self.height
971
+ self.width = firstphase_width
972
+ self.height = firstphase_height
973
+
974
+ self.truncate_x = 0
975
+ self.truncate_y = 0
976
+ self.applied_old_hires_behavior_to = None
977
+
978
+ self.hr_prompts = None
979
+ self.hr_negative_prompts = None
980
+ self.hr_extra_network_data = None
981
+
982
+ self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
983
+ self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
984
+ self.hr_c = None
985
+ self.hr_uc = None
986
+
987
+ def init(self, all_prompts, all_seeds, all_subseeds):
988
+ if self.enable_hr:
989
+ if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
990
+ self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
991
+
992
+ if tuple(self.hr_prompt) != tuple(self.prompt):
993
+ self.extra_generation_params["Hires prompt"] = self.hr_prompt
994
+
995
+ if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
996
+ self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
997
+
998
+ if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
999
+ self.hr_resize_x = self.width
1000
+ self.hr_resize_y = self.height
1001
+ self.hr_upscale_to_x = self.width
1002
+ self.hr_upscale_to_y = self.height
1003
+
1004
+ self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
1005
+ self.applied_old_hires_behavior_to = (self.width, self.height)
1006
+
1007
+ if self.hr_resize_x == 0 and self.hr_resize_y == 0:
1008
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
1009
+ self.hr_upscale_to_x = int(self.width * self.hr_scale)
1010
+ self.hr_upscale_to_y = int(self.height * self.hr_scale)
1011
+ else:
1012
+ self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
1013
+
1014
+ if self.hr_resize_y == 0:
1015
+ self.hr_upscale_to_x = self.hr_resize_x
1016
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
1017
+ elif self.hr_resize_x == 0:
1018
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
1019
+ self.hr_upscale_to_y = self.hr_resize_y
1020
+ else:
1021
+ target_w = self.hr_resize_x
1022
+ target_h = self.hr_resize_y
1023
+ src_ratio = self.width / self.height
1024
+ dst_ratio = self.hr_resize_x / self.hr_resize_y
1025
+
1026
+ if src_ratio < dst_ratio:
1027
+ self.hr_upscale_to_x = self.hr_resize_x
1028
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
1029
+ else:
1030
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
1031
+ self.hr_upscale_to_y = self.hr_resize_y
1032
+
1033
+ self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
1034
+ self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
1035
+
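+ # Worked example for the branch above (assuming the default opt_f of 8): a 512x512 first pass with
+ # hr_resize 768x1024 gives src_ratio=1.0 and dst_ratio=0.75, so the image is upscaled to 1024x1024
+ # to cover the target, and truncate_x = (1024 - 768) // 8 = 32 latent columns are cropped later in
+ # sample() to reach exactly 768x1024.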
1036
+ # special case: the user has chosen to do nothing
1037
+ if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
1038
+ self.enable_hr = False
1039
+ self.denoising_strength = None
1040
+ self.extra_generation_params.pop("Hires upscale", None)
1041
+ self.extra_generation_params.pop("Hires resize", None)
1042
+ return
1043
+
1044
+ if not state.processing_has_refined_job_count:
1045
+ if state.job_count == -1:
1046
+ state.job_count = self.n_iter
1047
+
1048
+ shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
1049
+ state.job_count = state.job_count * 2
1050
+ state.processing_has_refined_job_count = True
1051
+
1052
+ if self.hr_second_pass_steps:
1053
+ self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
1054
+
1055
+ if self.hr_upscaler is not None:
1056
+ self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
1057
+
1058
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
1059
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
1060
+
1061
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
1062
+ if self.enable_hr and latent_scale_mode is None:
1063
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
1064
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
1065
+
1066
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
1067
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
1068
+
1069
+ if not self.enable_hr:
1070
+ return samples
1071
+
1072
+ self.is_hr_pass = True
1073
+
1074
+ target_width = self.hr_upscale_to_x
1075
+ target_height = self.hr_upscale_to_y
1076
+
1077
+ def save_intermediate(image, index):
1078
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
1079
+
1080
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
1081
+ return
1082
+
1083
+ if not isinstance(image, Image.Image):
1084
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
1085
+
1086
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
1087
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
1088
+
1089
+ if latent_scale_mode is not None:
1090
+ for i in range(samples.shape[0]):
1091
+ save_intermediate(samples, i)
1092
+
1093
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
1094
+
1095
+ # Avoid building the inpainting conditioning unless necessary, since
1096
+ # it requires extra compute to decode / encode the image again.
1097
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
1098
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
1099
+ else:
1100
+ image_conditioning = self.txt2img_image_conditioning(samples)
1101
+ else:
1102
+ decoded_samples = decode_first_stage(self.sd_model, samples)
1103
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
1104
+
1105
+ batch_images = []
1106
+ for i, x_sample in enumerate(lowres_samples):
1107
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
1108
+ x_sample = x_sample.astype(np.uint8)
1109
+ image = Image.fromarray(x_sample)
1110
+
1111
+ save_intermediate(image, i)
1112
+
1113
+ image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
1114
+ image = np.array(image).astype(np.float32) / 255.0
1115
+ image = np.moveaxis(image, 2, 0)
1116
+ batch_images.append(image)
1117
+
1118
+ decoded_samples = torch.from_numpy(np.array(batch_images))
1119
+ decoded_samples = decoded_samples.to(shared.device)
1120
+ decoded_samples = 2. * decoded_samples - 1.
1121
+
1122
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
1123
+
1124
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
1125
+
1126
+ shared.state.nextjob()
1127
+
1128
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
1129
+
1130
+ if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
1131
+ img2img_sampler_name = 'DDIM'
1132
+
1133
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
1134
+
1135
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
1136
+
1137
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
1138
+
1139
+ # GC now before running the next img2img to prevent running out of memory
1140
+ x = None
1141
+ devices.torch_gc()
1142
+
1143
+ if not self.disable_extra_networks:
1144
+ with devices.autocast():
1145
+ extra_networks.activate(self, self.hr_extra_network_data)
1146
+
1147
+ with devices.autocast():
1148
+ self.calculate_hr_conds()
1149
+
1150
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
1151
+
1152
+ if self.scripts is not None:
1153
+ self.scripts.before_hr(self)
1154
+
1155
+ samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
1156
+
1157
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
1158
+
1159
+ self.is_hr_pass = False
1160
+
1161
+ return samples
1162
+
1163
+ def close(self):
1164
+ super().close()
1165
+ self.hr_c = None
1166
+ self.hr_uc = None
1167
+ if not opts.experimental_persistent_cond_cache:
1168
+ StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
1169
+ StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
1170
+
1171
+ def setup_prompts(self):
1172
+ super().setup_prompts()
1173
+
1174
+ if not self.enable_hr:
1175
+ return
1176
+
1177
+ if self.hr_prompt == '':
1178
+ self.hr_prompt = self.prompt
1179
+
1180
+ if self.hr_negative_prompt == '':
1181
+ self.hr_negative_prompt = self.negative_prompt
1182
+
1183
+ if type(self.hr_prompt) == list:
1184
+ self.all_hr_prompts = self.hr_prompt
1185
+ else:
1186
+ self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
1187
+
1188
+ if type(self.hr_negative_prompt) == list:
1189
+ self.all_hr_negative_prompts = self.hr_negative_prompt
1190
+ else:
1191
+ self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
1192
+
1193
+ self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
1194
+ self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
1195
+
1196
+ def calculate_hr_conds(self):
1197
+ if self.hr_c is not None:
1198
+ return
1199
+
1200
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
1201
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
1202
+
1203
+ def setup_conds(self):
1204
+ super().setup_conds()
1205
+
1206
+ self.hr_uc = None
1207
+ self.hr_c = None
1208
+
1209
+ if self.enable_hr:
1210
+ if shared.opts.hires_fix_use_firstpass_conds:
1211
+ self.calculate_hr_conds()
1212
+
1213
+ elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
1214
+ with devices.autocast():
1215
+ extra_networks.activate(self, self.hr_extra_network_data)
1216
+
1217
+ self.calculate_hr_conds()
1218
+
1219
+ with devices.autocast():
1220
+ extra_networks.activate(self, self.extra_network_data)
1221
+
1222
+ def parse_extra_network_prompts(self):
1223
+ res = super().parse_extra_network_prompts()
1224
+
1225
+ if self.enable_hr:
1226
+ self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
1227
+ self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
1228
+
1229
+ self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
1230
+
1231
+ return res
1232
+
1233
+
1234
+ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
1235
+ sampler = None
1236
+
1237
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
1238
+ super().__init__(**kwargs)
1239
+
1240
+ self.init_images = init_images
1241
+ self.resize_mode: int = resize_mode
1242
+ self.denoising_strength: float = denoising_strength
1243
+ self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
1244
+ self.init_latent = None
1245
+ self.image_mask = mask
1246
+ self.latent_mask = None
1247
+ self.mask_for_overlay = None
1248
+ if mask_blur is not None:
1249
+ mask_blur_x = mask_blur
1250
+ mask_blur_y = mask_blur
1251
+ self.mask_blur_x = mask_blur_x
1252
+ self.mask_blur_y = mask_blur_y
1253
+ self.inpainting_fill = inpainting_fill
1254
+ self.inpaint_full_res = inpaint_full_res
1255
+ self.inpaint_full_res_padding = inpaint_full_res_padding
1256
+ self.inpainting_mask_invert = inpainting_mask_invert
1257
+ self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
1258
+ self.mask = None
1259
+ self.nmask = None
1260
+ self.image_conditioning = None
1261
+
1262
+ def init(self, all_prompts, all_seeds, all_subseeds):
1263
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
1264
+ crop_region = None
1265
+
1266
+ image_mask = self.image_mask
1267
+
1268
+ if image_mask is not None:
1269
+ image_mask = image_mask.convert('L')
1270
+
1271
+ if self.inpainting_mask_invert:
1272
+ image_mask = ImageOps.invert(image_mask)
1273
+
1274
+ if self.mask_blur_x > 0:
1275
+ np_mask = np.array(image_mask)
1276
+ kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
1277
+ np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
1278
+ image_mask = Image.fromarray(np_mask)
1279
+
1280
+ if self.mask_blur_y > 0:
1281
+ np_mask = np.array(image_mask)
1282
+ kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
1283
+ np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
1284
+ image_mask = Image.fromarray(np_mask)
1285
+
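+ # For example, the default mask_blur of 4 gives kernel_size = 2 * int(4 * 4 + 0.5) + 1 = 33, i.e. a
+ # 33x1 horizontal pass followed by a 1x33 vertical Gaussian blur of the mask.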
1286
+ if self.inpaint_full_res:
1287
+ self.mask_for_overlay = image_mask
1288
+ mask = image_mask.convert('L')
1289
+ crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
1290
+ crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
1291
+ x1, y1, x2, y2 = crop_region
1292
+
1293
+ mask = mask.crop(crop_region)
1294
+ image_mask = images.resize_image(2, mask, self.width, self.height)
1295
+ self.paste_to = (x1, y1, x2-x1, y2-y1)
1296
+ else:
1297
+ image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
1298
+ np_mask = np.array(image_mask)
1299
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
1300
+ self.mask_for_overlay = Image.fromarray(np_mask)
1301
+
1302
+ self.overlay_images = []
1303
+
1304
+ latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
1305
+
1306
+ add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
1307
+ if add_color_corrections:
1308
+ self.color_corrections = []
1309
+ imgs = []
1310
+ for img in self.init_images:
1311
+
1312
+ # Save init image
1313
+ if opts.save_init_img:
1314
+ self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
1315
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
1316
+
1317
+ image = images.flatten(img, opts.img2img_background_color)
1318
+
1319
+ if crop_region is None and self.resize_mode != 3:
1320
+ image = images.resize_image(self.resize_mode, image, self.width, self.height)
1321
+
1322
+ if image_mask is not None:
1323
+ image_masked = Image.new('RGBa', (image.width, image.height))
1324
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
1325
+
1326
+ self.overlay_images.append(image_masked.convert('RGBA'))
1327
+
1328
+ # crop_region is not None if we are doing inpaint full res
1329
+ if crop_region is not None:
1330
+ image = image.crop(crop_region)
1331
+ image = images.resize_image(2, image, self.width, self.height)
1332
+
1333
+ if image_mask is not None:
1334
+ if self.inpainting_fill != 1:
1335
+ image = masking.fill(image, latent_mask)
1336
+
1337
+ if add_color_corrections:
1338
+ self.color_corrections.append(setup_color_correction(image))
1339
+
1340
+ image = np.array(image).astype(np.float32) / 255.0
1341
+ image = np.moveaxis(image, 2, 0)
1342
+
1343
+ imgs.append(image)
1344
+
1345
+ if len(imgs) == 1:
1346
+ batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
1347
+ if self.overlay_images is not None:
1348
+ self.overlay_images = self.overlay_images * self.batch_size
1349
+
1350
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
1351
+ self.color_corrections = self.color_corrections * self.batch_size
1352
+
1353
+ elif len(imgs) <= self.batch_size:
1354
+ self.batch_size = len(imgs)
1355
+ batch_images = np.array(imgs)
1356
+ else:
1357
+ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
1358
+
1359
+ image = torch.from_numpy(batch_images)
1360
+ image = 2. * image - 1.
1361
+ image = image.to(shared.device, dtype=devices.dtype_vae)
1362
+
1363
+ self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
1364
+
1365
+ if self.resize_mode == 3:
1366
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
1367
+
1368
+ if image_mask is not None:
1369
+ init_mask = latent_mask
1370
+ latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
1371
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
1372
+ latmask = latmask[0]
1373
+ latmask = np.around(latmask)
1374
+ latmask = np.tile(latmask[None], (4, 1, 1))
1375
+
1376
+ self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
1377
+ self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
1378
+
1379
+ # this needs to be fixed to be done in sample() using actual seeds for batches
1380
+ if self.inpainting_fill == 2:
1381
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
1382
+ elif self.inpainting_fill == 3:
1383
+ self.init_latent = self.init_latent * self.mask
1384
+
1385
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
1386
+
1387
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
1388
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
1389
+
1390
+ if self.initial_noise_multiplier != 1.0:
1391
+ self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
1392
+ x *= self.initial_noise_multiplier
1393
+
1394
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
1395
+
1396
+ if self.mask is not None:
1397
+ samples = samples * self.nmask + self.init_latent * self.mask
1398
+
1399
+ del x
1400
+ devices.torch_gc()
1401
+
1402
+ return samples
1403
+
1404
+ def get_token_merging_ratio(self, for_hr=False):
1405
+ return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio