aclicheroux commited on
Commit
e0c66e4
1 Parent(s): d9f9915

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CODEOWNERS +1 -0
  2. app.py +137 -0
  3. artists.csv +0 -0
  4. environment-wsl2.yaml +11 -0
  5. javascript/aspectRatioOverlay.js +119 -0
  6. javascript/contextMenus.js +177 -0
  7. javascript/dragdrop.js +86 -0
  8. javascript/edit-attention.js +45 -0
  9. javascript/hints.js +121 -0
  10. javascript/imageMaskFix.js +45 -0
  11. javascript/imageviewer.js +236 -0
  12. javascript/notification.js +49 -0
  13. javascript/progressbar.js +76 -0
  14. javascript/textualInversion.js +8 -0
  15. javascript/ui.js +234 -0
  16. launch.py +169 -0
  17. modules/artists.py +25 -0
  18. modules/bsrgan_model.py +76 -0
  19. modules/bsrgan_model_arch.py +102 -0
  20. modules/codeformer/codeformer_arch.py +278 -0
  21. modules/codeformer/vqgan_arch.py +437 -0
  22. modules/codeformer_model.py +140 -0
  23. modules/deepbooru.py +173 -0
  24. modules/devices.py +72 -0
  25. modules/errors.py +10 -0
  26. modules/esrgan_model.py +158 -0
  27. modules/esrgan_model_arch.py +80 -0
  28. modules/extras.py +222 -0
  29. modules/face_restoration.py +19 -0
  30. modules/generation_parameters_copypaste.py +101 -0
  31. modules/gfpgan_model.py +115 -0
  32. modules/hypernetworks/hypernetwork.py +314 -0
  33. modules/hypernetworks/ui.py +47 -0
  34. modules/images.py +465 -0
  35. modules/img2img.py +137 -0
  36. modules/interrogate.py +171 -0
  37. modules/ldsr_model.py +54 -0
  38. modules/ldsr_model_arch.py +222 -0
  39. modules/lowvram.py +82 -0
  40. modules/masking.py +99 -0
  41. modules/memmon.py +85 -0
  42. modules/modelloader.py +153 -0
  43. modules/ngrok.py +15 -0
  44. modules/paths.py +40 -0
  45. modules/processing.py +721 -0
  46. modules/prompt_parser.py +366 -0
  47. modules/realesrgan_model.py +133 -0
  48. modules/safe.py +110 -0
  49. modules/safety.py +42 -0
  50. modules/scripts.py +201 -0
CODEOWNERS ADDED
@@ -0,0 +1 @@
 
 
1
+ * @AUTOMATIC1111
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ import time
4
+ import importlib
5
+ import signal
6
+ import threading
7
+
8
+ from fastapi.middleware.gzip import GZipMiddleware
9
+
10
+ from modules.paths import script_path
11
+
12
+ from modules import devices, sd_samplers
13
+ import modules.codeformer_model as codeformer
14
+ import modules.extras
15
+ import modules.face_restoration
16
+ import modules.gfpgan_model as gfpgan
17
+ import modules.img2img
18
+
19
+ import modules.lowvram
20
+ import modules.paths
21
+ import modules.scripts
22
+ import modules.sd_hijack
23
+ import modules.sd_models
24
+ import modules.shared as shared
25
+ import modules.txt2img
26
+
27
+ import modules.ui
28
+ from modules import devices
29
+ from modules import modelloader
30
+ from modules.paths import script_path
31
+ from modules.shared import cmd_opts
32
+ import modules.hypernetworks.hypernetwork
33
+
34
+
35
+ queue_lock = threading.Lock()
36
+
37
+
38
+ def wrap_queued_call(func):
39
+ def f(*args, **kwargs):
40
+ with queue_lock:
41
+ res = func(*args, **kwargs)
42
+
43
+ return res
44
+
45
+ return f
46
+
47
+
48
+ def wrap_gradio_gpu_call(func, extra_outputs=None):
49
+ def f(*args, **kwargs):
50
+ devices.torch_gc()
51
+
52
+ shared.state.sampling_step = 0
53
+ shared.state.job_count = -1
54
+ shared.state.job_no = 0
55
+ shared.state.job_timestamp = shared.state.get_job_timestamp()
56
+ shared.state.current_latent = None
57
+ shared.state.current_image = None
58
+ shared.state.current_image_sampling_step = 0
59
+ shared.state.skipped = False
60
+ shared.state.interrupted = False
61
+ shared.state.textinfo = None
62
+
63
+ with queue_lock:
64
+ res = func(*args, **kwargs)
65
+
66
+ shared.state.job = ""
67
+ shared.state.job_count = 0
68
+
69
+ devices.torch_gc()
70
+
71
+ return res
72
+
73
+ return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
74
+
75
+ def initialize():
76
+ modelloader.cleanup_models()
77
+ modules.sd_models.setup_model()
78
+ codeformer.setup_model(cmd_opts.codeformer_models_path)
79
+ gfpgan.setup_model(cmd_opts.gfpgan_models_path)
80
+ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
81
+ modelloader.load_upscalers()
82
+
83
+ modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
84
+
85
+ shared.sd_model = modules.sd_models.load_model()
86
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
87
+ shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
88
+ shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
89
+
90
+
91
+ def webui():
92
+ initialize()
93
+
94
+ # make the program just exit at ctrl+c without waiting for anything
95
+ def sigint_handler(sig, frame):
96
+ print(f'Interrupted with signal {sig} in {frame}')
97
+ os._exit(0)
98
+
99
+ signal.signal(signal.SIGINT, sigint_handler)
100
+
101
+ while 1:
102
+
103
+ demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
104
+
105
+ app, local_url, share_url = demo.launch(
106
+ share=cmd_opts.share,
107
+ server_name="0.0.0.0" if cmd_opts.listen else None,
108
+ server_port=cmd_opts.port,
109
+ debug=cmd_opts.gradio_debug,
110
+ auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
111
+ inbrowser=cmd_opts.autolaunch,
112
+ prevent_thread_lock=True
113
+ )
114
+
115
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
116
+
117
+ while 1:
118
+ time.sleep(0.5)
119
+ if getattr(demo, 'do_restart', False):
120
+ time.sleep(0.5)
121
+ demo.close()
122
+ time.sleep(0.5)
123
+ break
124
+
125
+ sd_samplers.set_samplers()
126
+
127
+ print('Reloading Custom Scripts')
128
+ modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
129
+ print('Reloading modules: modules.ui')
130
+ importlib.reload(modules.ui)
131
+ print('Refreshing Model List')
132
+ modules.sd_models.list_models()
133
+ print('Restarting Gradio')
134
+
135
+
136
+ if __name__ == "__main__":
137
+ webui()
artists.csv ADDED
The diff for this file is too large to render. See raw diff
 
environment-wsl2.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: automatic
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.10
7
+ - pip=22.2.2
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.12.1
10
+ - torchvision=0.13.1
11
+ - numpy=1.23.1
javascript/aspectRatioOverlay.js ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ let currentWidth = null;
3
+ let currentHeight = null;
4
+ let arFrameTimeout = setTimeout(function(){},0);
5
+
6
+ function dimensionChange(e,dimname){
7
+
8
+ if(dimname == 'Width'){
9
+ currentWidth = e.target.value*1.0
10
+ }
11
+ if(dimname == 'Height'){
12
+ currentHeight = e.target.value*1.0
13
+ }
14
+
15
+ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
16
+
17
+ if(!inImg2img){
18
+ return;
19
+ }
20
+
21
+ var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
22
+ if(img2imgMode){
23
+ img2imgMode=img2imgMode.innerText
24
+ }else{
25
+ return;
26
+ }
27
+
28
+ var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
29
+ var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')
30
+
31
+ var targetElement = null;
32
+
33
+ if(img2imgMode=='img2img' && redrawImage){
34
+ targetElement = redrawImage;
35
+ }else if(img2imgMode=='Inpaint' && inpaintImage){
36
+ targetElement = inpaintImage;
37
+ }
38
+
39
+ if(targetElement){
40
+
41
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
42
+ if(!arPreviewRect){
43
+ arPreviewRect = document.createElement('div')
44
+ arPreviewRect.id = "imageARPreview";
45
+ gradioApp().getRootNode().appendChild(arPreviewRect)
46
+ }
47
+
48
+
49
+
50
+ var viewportOffset = targetElement.getBoundingClientRect();
51
+
52
+ viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
53
+
54
+ scaledx = targetElement.naturalWidth*viewportscale
55
+ scaledy = targetElement.naturalHeight*viewportscale
56
+
57
+ cleintRectTop = (viewportOffset.top+window.scrollY)
58
+ cleintRectLeft = (viewportOffset.left+window.scrollX)
59
+ cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
60
+ cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
61
+
62
+ viewRectTop = cleintRectCentreY-(scaledy/2)
63
+ viewRectLeft = cleintRectCentreX-(scaledx/2)
64
+ arRectWidth = scaledx
65
+ arRectHeight = scaledy
66
+
67
+ arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
68
+ arscaledx = currentWidth*arscale
69
+ arscaledy = currentHeight*arscale
70
+
71
+ arRectTop = cleintRectCentreY-(arscaledy/2)
72
+ arRectLeft = cleintRectCentreX-(arscaledx/2)
73
+ arRectWidth = arscaledx
74
+ arRectHeight = arscaledy
75
+
76
+ arPreviewRect.style.top = arRectTop+'px';
77
+ arPreviewRect.style.left = arRectLeft+'px';
78
+ arPreviewRect.style.width = arRectWidth+'px';
79
+ arPreviewRect.style.height = arRectHeight+'px';
80
+
81
+ clearTimeout(arFrameTimeout);
82
+ arFrameTimeout = setTimeout(function(){
83
+ arPreviewRect.style.display = 'none';
84
+ },2000);
85
+
86
+ arPreviewRect.style.display = 'block';
87
+
88
+ }
89
+
90
+ }
91
+
92
+
93
+ onUiUpdate(function(){
94
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
95
+ if(arPreviewRect){
96
+ arPreviewRect.style.display = 'none';
97
+ }
98
+ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
99
+ if(inImg2img){
100
+ let inputs = gradioApp().querySelectorAll('input');
101
+ inputs.forEach(function(e){
102
+ let parentLabel = e.parentElement.querySelector('label')
103
+ if(parentLabel && parentLabel.innerText){
104
+ if(!e.classList.contains('scrollwatch')){
105
+ if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
106
+ e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
107
+ e.classList.add('scrollwatch')
108
+ }
109
+ if(parentLabel.innerText == 'Width'){
110
+ currentWidth = e.value*1.0
111
+ }
112
+ if(parentLabel.innerText == 'Height'){
113
+ currentHeight = e.value*1.0
114
+ }
115
+ }
116
+ }
117
+ })
118
+ }
119
+ });
javascript/contextMenus.js ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ contextMenuInit = function(){
3
+ let eventListenerApplied=false;
4
+ let menuSpecs = new Map();
5
+
6
+ const uid = function(){
7
+ return Date.now().toString(36) + Math.random().toString(36).substr(2);
8
+ }
9
+
10
+ function showContextMenu(event,element,menuEntries){
11
+ let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
12
+ let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
13
+
14
+ let oldMenu = gradioApp().querySelector('#context-menu')
15
+ if(oldMenu){
16
+ oldMenu.remove()
17
+ }
18
+
19
+ let tabButton = uiCurrentTab
20
+ let baseStyle = window.getComputedStyle(tabButton)
21
+
22
+ const contextMenu = document.createElement('nav')
23
+ contextMenu.id = "context-menu"
24
+ contextMenu.style.background = baseStyle.background
25
+ contextMenu.style.color = baseStyle.color
26
+ contextMenu.style.fontFamily = baseStyle.fontFamily
27
+ contextMenu.style.top = posy+'px'
28
+ contextMenu.style.left = posx+'px'
29
+
30
+
31
+
32
+ const contextMenuList = document.createElement('ul')
33
+ contextMenuList.className = 'context-menu-items';
34
+ contextMenu.append(contextMenuList);
35
+
36
+ menuEntries.forEach(function(entry){
37
+ let contextMenuEntry = document.createElement('a')
38
+ contextMenuEntry.innerHTML = entry['name']
39
+ contextMenuEntry.addEventListener("click", function(e) {
40
+ entry['func']();
41
+ })
42
+ contextMenuList.append(contextMenuEntry);
43
+
44
+ })
45
+
46
+ gradioApp().getRootNode().appendChild(contextMenu)
47
+
48
+ let menuWidth = contextMenu.offsetWidth + 4;
49
+ let menuHeight = contextMenu.offsetHeight + 4;
50
+
51
+ let windowWidth = window.innerWidth;
52
+ let windowHeight = window.innerHeight;
53
+
54
+ if ( (windowWidth - posx) < menuWidth ) {
55
+ contextMenu.style.left = windowWidth - menuWidth + "px";
56
+ }
57
+
58
+ if ( (windowHeight - posy) < menuHeight ) {
59
+ contextMenu.style.top = windowHeight - menuHeight + "px";
60
+ }
61
+
62
+ }
63
+
64
+ function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
65
+
66
+ currentItems = menuSpecs.get(targetEmementSelector)
67
+
68
+ if(!currentItems){
69
+ currentItems = []
70
+ menuSpecs.set(targetEmementSelector,currentItems);
71
+ }
72
+ let newItem = {'id':targetEmementSelector+'_'+uid(),
73
+ 'name':entryName,
74
+ 'func':entryFunction,
75
+ 'isNew':true}
76
+
77
+ currentItems.push(newItem)
78
+ return newItem['id']
79
+ }
80
+
81
+ function removeContextMenuOption(uid){
82
+ menuSpecs.forEach(function(v,k) {
83
+ let index = -1
84
+ v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
85
+ if(index>=0){
86
+ v.splice(index, 1);
87
+ }
88
+ })
89
+ }
90
+
91
+ function addContextMenuEventListener(){
92
+ if(eventListenerApplied){
93
+ return;
94
+ }
95
+ gradioApp().addEventListener("click", function(e) {
96
+ let source = e.composedPath()[0]
97
+ if(source.id && source.id.indexOf('check_progress')>-1){
98
+ return
99
+ }
100
+
101
+ let oldMenu = gradioApp().querySelector('#context-menu')
102
+ if(oldMenu){
103
+ oldMenu.remove()
104
+ }
105
+ });
106
+ gradioApp().addEventListener("contextmenu", function(e) {
107
+ let oldMenu = gradioApp().querySelector('#context-menu')
108
+ if(oldMenu){
109
+ oldMenu.remove()
110
+ }
111
+ menuSpecs.forEach(function(v,k) {
112
+ if(e.composedPath()[0].matches(k)){
113
+ showContextMenu(e,e.composedPath()[0],v)
114
+ e.preventDefault()
115
+ return
116
+ }
117
+ })
118
+ });
119
+ eventListenerApplied=true
120
+
121
+ }
122
+
123
+ return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
124
+ }
125
+
126
+ initResponse = contextMenuInit();
127
+ appendContextMenuOption = initResponse[0];
128
+ removeContextMenuOption = initResponse[1];
129
+ addContextMenuEventListener = initResponse[2];
130
+
131
+ (function(){
132
+ //Start example Context Menu Items
133
+ let generateOnRepeat = function(genbuttonid,interruptbuttonid){
134
+ let genbutton = gradioApp().querySelector(genbuttonid);
135
+ let interruptbutton = gradioApp().querySelector(interruptbuttonid);
136
+ if(!interruptbutton.offsetParent){
137
+ genbutton.click();
138
+ }
139
+ clearInterval(window.generateOnRepeatInterval)
140
+ window.generateOnRepeatInterval = setInterval(function(){
141
+ if(!interruptbutton.offsetParent){
142
+ genbutton.click();
143
+ }
144
+ },
145
+ 500)
146
+ }
147
+
148
+ appendContextMenuOption('#txt2img_generate','Generate forever',function(){
149
+ generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
150
+ })
151
+ appendContextMenuOption('#img2img_generate','Generate forever',function(){
152
+ generateOnRepeat('#img2img_generate','#img2img_interrupt');
153
+ })
154
+
155
+ let cancelGenerateForever = function(){
156
+ clearInterval(window.generateOnRepeatInterval)
157
+ }
158
+
159
+ appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
160
+ appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
161
+ appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
162
+ appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
163
+
164
+ appendContextMenuOption('#roll','Roll three',
165
+ function(){
166
+ let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
167
+ setTimeout(function(){rollbutton.click()},100)
168
+ setTimeout(function(){rollbutton.click()},200)
169
+ setTimeout(function(){rollbutton.click()},300)
170
+ }
171
+ )
172
+ })();
173
+ //End example Context Menu Items
174
+
175
+ onUiUpdate(function(){
176
+ addContextMenuEventListener()
177
+ });
javascript/dragdrop.js ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // allows drag-dropping files into gradio image elements, and also pasting images from clipboard
2
+
3
+ function isValidImageList( files ) {
4
+ return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
5
+ }
6
+
7
+ function dropReplaceImage( imgWrap, files ) {
8
+ if ( ! isValidImageList( files ) ) {
9
+ return;
10
+ }
11
+
12
+ imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
13
+ const callback = () => {
14
+ const fileInput = imgWrap.querySelector('input[type="file"]');
15
+ if ( fileInput ) {
16
+ fileInput.files = files;
17
+ fileInput.dispatchEvent(new Event('change'));
18
+ }
19
+ };
20
+
21
+ if ( imgWrap.closest('#pnginfo_image') ) {
22
+ // special treatment for PNG Info tab, wait for fetch request to finish
23
+ const oldFetch = window.fetch;
24
+ window.fetch = async (input, options) => {
25
+ const response = await oldFetch(input, options);
26
+ if ( 'api/predict/' === input ) {
27
+ const content = await response.text();
28
+ window.fetch = oldFetch;
29
+ window.requestAnimationFrame( () => callback() );
30
+ return new Response(content, {
31
+ status: response.status,
32
+ statusText: response.statusText,
33
+ headers: response.headers
34
+ })
35
+ }
36
+ return response;
37
+ };
38
+ } else {
39
+ window.requestAnimationFrame( () => callback() );
40
+ }
41
+ }
42
+
43
+ window.document.addEventListener('dragover', e => {
44
+ const target = e.composedPath()[0];
45
+ const imgWrap = target.closest('[data-testid="image"]');
46
+ if ( !imgWrap ) {
47
+ return;
48
+ }
49
+ e.stopPropagation();
50
+ e.preventDefault();
51
+ e.dataTransfer.dropEffect = 'copy';
52
+ });
53
+
54
+ window.document.addEventListener('drop', e => {
55
+ const target = e.composedPath()[0];
56
+ const imgWrap = target.closest('[data-testid="image"]');
57
+ if ( !imgWrap ) {
58
+ return;
59
+ }
60
+ e.stopPropagation();
61
+ e.preventDefault();
62
+ const files = e.dataTransfer.files;
63
+ dropReplaceImage( imgWrap, files );
64
+ });
65
+
66
+ window.addEventListener('paste', e => {
67
+ const files = e.clipboardData.files;
68
+ if ( ! isValidImageList( files ) ) {
69
+ return;
70
+ }
71
+
72
+ const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
73
+ .filter(el => uiElementIsVisible(el));
74
+ if ( ! visibleImageFields.length ) {
75
+ return;
76
+ }
77
+
78
+ const firstFreeImageField = visibleImageFields
79
+ .filter(el => el.querySelector('input[type=file]'))?.[0];
80
+
81
+ dropReplaceImage(
82
+ firstFreeImageField ?
83
+ firstFreeImageField :
84
+ visibleImageFields[visibleImageFields.length - 1]
85
+ , files );
86
+ });
javascript/edit-attention.js ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addEventListener('keydown', (event) => {
2
+ let target = event.originalTarget || event.composedPath()[0];
3
+ if (!target.hasAttribute("placeholder")) return;
4
+ if (!target.placeholder.toLowerCase().includes("prompt")) return;
5
+
6
+ let plus = "ArrowUp"
7
+ let minus = "ArrowDown"
8
+ if (event.key != plus && event.key != minus) return;
9
+
10
+ selectionStart = target.selectionStart;
11
+ selectionEnd = target.selectionEnd;
12
+ if(selectionStart == selectionEnd) return;
13
+
14
+ event.preventDefault();
15
+
16
+ if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
17
+ target.value = target.value.slice(0, selectionStart) +
18
+ "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
19
+ target.value.slice(selectionEnd);
20
+
21
+ target.focus();
22
+ target.selectionStart = selectionStart + 1;
23
+ target.selectionEnd = selectionEnd + 1;
24
+
25
+ } else {
26
+ end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
27
+ weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
28
+ if (isNaN(weight)) return;
29
+ if (event.key == minus) weight -= 0.1;
30
+ if (event.key == plus) weight += 0.1;
31
+
32
+ weight = parseFloat(weight.toPrecision(12));
33
+
34
+ target.value = target.value.slice(0, selectionEnd + 1) +
35
+ weight +
36
+ target.value.slice(selectionEnd + 1 + end - 1);
37
+
38
+ target.focus();
39
+ target.selectionStart = selectionStart;
40
+ target.selectionEnd = selectionEnd;
41
+ }
42
+ // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
43
+ // internal Svelte data binding remains in sync.
44
+ target.dispatchEvent(new Event("input", { bubbles: true }));
45
+ });
javascript/hints.js ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // mouseover tooltips for various UI elements
2
+
3
+ titles = {
4
+ "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
5
+ "Sampling method": "Which algorithm to use to produce the image",
6
+ "GFPGAN": "Restore low quality faces using GFPGAN neural network",
7
+ "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
8
+ "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
9
+
10
+ "Batch count": "How many batches of images to create",
11
+ "Batch size": "How many image to create in a single batch",
12
+ "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
13
+ "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
14
+ "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
15
+ "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
16
+ "\u{1f3a8}": "Add a random artist to the prompt.",
17
+ "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
18
+ "\u{1f4c2}": "Open images output directory",
19
+
20
+ "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
21
+ "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
22
+
23
+ "Just resize": "Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.",
24
+ "Crop and resize": "Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.",
25
+ "Resize and fill": "Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors.",
26
+
27
+ "Mask blur": "How much to blur the mask before processing, in pixels.",
28
+ "Masked content": "What to put inside the masked area before processing it with Stable Diffusion.",
29
+ "fill": "fill it with colors of the image",
30
+ "original": "keep whatever was there originally",
31
+ "latent noise": "fill it with latent space noise",
32
+ "latent nothing": "fill it with latent space zeroes",
33
+ "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
34
+
35
+ "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
36
+ "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
37
+
38
+ "Skip": "Stop processing current image and continue processing.",
39
+ "Interrupt": "Stop processing images and return any results accumulated so far.",
40
+ "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
41
+
42
+ "X values": "Separate values for X axis using commas.",
43
+ "Y values": "Separate values for Y axis using commas.",
44
+
45
+ "None": "Do not do anything special",
46
+ "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
47
+ "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
48
+ "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
49
+
50
+ "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
51
+ "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
52
+
53
+ "Tiling": "Produce an image that can be tiled.",
54
+ "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
55
+
56
+ "Variation seed": "Seed of a different picture to be mixed into the generation.",
57
+ "Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
58
+ "Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
59
+ "Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
60
+
61
+ "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
62
+
63
+ "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
64
+ "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
65
+ "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
66
+
67
+ "Loopback": "Process an image, use it as an input, repeat.",
68
+ "Loops": "How many times to repeat processing an image and using it as input for the next iteration",
69
+
70
+ "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
71
+ "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
72
+ "Apply style": "Insert selected styles into prompt fields",
73
+ "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
74
+
75
+ "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
76
+
77
+ "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
78
+
79
+ "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
80
+ "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
81
+
82
+ "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
83
+ "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
84
+
85
+ "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
86
+ "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.",
87
+
88
+ "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply."
89
+ }
90
+
91
+
92
+ onUiUpdate(function(){
93
+ gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
94
+ tooltip = titles[span.textContent];
95
+
96
+ if(!tooltip){
97
+ tooltip = titles[span.value];
98
+ }
99
+
100
+ if(!tooltip){
101
+ for (const c of span.classList) {
102
+ if (c in titles) {
103
+ tooltip = titles[c];
104
+ break;
105
+ }
106
+ }
107
+ }
108
+
109
+ if(tooltip){
110
+ span.title = tooltip;
111
+ }
112
+ })
113
+
114
+ gradioApp().querySelectorAll('select').forEach(function(select){
115
+ if (select.onchange != null) return;
116
+
117
+ select.onchange = function(){
118
+ select.title = titles[select.value] || "";
119
+ }
120
+ })
121
+ })
javascript/imageMaskFix.js ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
3
+ * @see https://github.com/gradio-app/gradio/issues/1721
4
+ */
5
+ window.addEventListener( 'resize', () => imageMaskResize());
6
+ function imageMaskResize() {
7
+ const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
8
+ if ( ! canvases.length ) {
9
+ canvases_fixed = false;
10
+ window.removeEventListener( 'resize', imageMaskResize );
11
+ return;
12
+ }
13
+
14
+ const wrapper = canvases[0].closest('.touch-none');
15
+ const previewImage = wrapper.previousElementSibling;
16
+
17
+ if ( ! previewImage.complete ) {
18
+ previewImage.addEventListener( 'load', () => imageMaskResize());
19
+ return;
20
+ }
21
+
22
+ const w = previewImage.width;
23
+ const h = previewImage.height;
24
+ const nw = previewImage.naturalWidth;
25
+ const nh = previewImage.naturalHeight;
26
+ const portrait = nh > nw;
27
+ const factor = portrait;
28
+
29
+ const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
30
+ const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
31
+
32
+ wrapper.style.width = `${wW}px`;
33
+ wrapper.style.height = `${wH}px`;
34
+ wrapper.style.left = `${(w-wW)/2}px`;
35
+ wrapper.style.top = `${(h-wH)/2}px`;
36
+
37
+ canvases.forEach( c => {
38
+ c.style.width = c.style.height = '';
39
+ c.style.maxWidth = '100%';
40
+ c.style.maxHeight = '100%';
41
+ c.style.objectFit = 'contain';
42
+ });
43
+ }
44
+
45
+ onUiUpdate(() => imageMaskResize());
javascript/imageviewer.js ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // A full size 'lightbox' preview modal shown when left clicking on gallery previews
2
+ function closeModal() {
3
+ gradioApp().getElementById("lightboxModal").style.display = "none";
4
+ }
5
+
6
+ function showModal(event) {
7
+ const source = event.target || event.srcElement;
8
+ const modalImage = gradioApp().getElementById("modalImage")
9
+ const lb = gradioApp().getElementById("lightboxModal")
10
+ modalImage.src = source.src
11
+ if (modalImage.style.display === 'none') {
12
+ lb.style.setProperty('background-image', 'url(' + source.src + ')');
13
+ }
14
+ lb.style.display = "block";
15
+ lb.focus()
16
+ event.stopPropagation()
17
+ }
18
+
19
+ function negmod(n, m) {
20
+ return ((n % m) + m) % m;
21
+ }
22
+
23
+ function updateOnBackgroundChange() {
24
+ const modalImage = gradioApp().getElementById("modalImage")
25
+ if (modalImage && modalImage.offsetParent) {
26
+ let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
27
+ let currentButton = null
28
+ allcurrentButtons.forEach(function(elem) {
29
+ if (elem.parentElement.offsetParent) {
30
+ currentButton = elem;
31
+ }
32
+ })
33
+
34
+ if (modalImage.src != currentButton.children[0].src) {
35
+ modalImage.src = currentButton.children[0].src;
36
+ if (modalImage.style.display === 'none') {
37
+ modal.style.setProperty('background-image', `url(${modalImage.src})`)
38
+ }
39
+ }
40
+ }
41
+ }
42
+
43
+ function modalImageSwitch(offset) {
44
+ var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
45
+ var galleryButtons = []
46
+ allgalleryButtons.forEach(function(elem) {
47
+ if (elem.parentElement.offsetParent) {
48
+ galleryButtons.push(elem);
49
+ }
50
+ })
51
+
52
+ if (galleryButtons.length > 1) {
53
+ var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
54
+ var currentButton = null
55
+ allcurrentButtons.forEach(function(elem) {
56
+ if (elem.parentElement.offsetParent) {
57
+ currentButton = elem;
58
+ }
59
+ })
60
+
61
+ var result = -1
62
+ galleryButtons.forEach(function(v, i) {
63
+ if (v == currentButton) {
64
+ result = i
65
+ }
66
+ })
67
+
68
+ if (result != -1) {
69
+ nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
70
+ nextButton.click()
71
+ const modalImage = gradioApp().getElementById("modalImage");
72
+ const modal = gradioApp().getElementById("lightboxModal");
73
+ modalImage.src = nextButton.children[0].src;
74
+ if (modalImage.style.display === 'none') {
75
+ modal.style.setProperty('background-image', `url(${modalImage.src})`)
76
+ }
77
+ setTimeout(function() {
78
+ modal.focus()
79
+ }, 10)
80
+ }
81
+ }
82
+ }
83
+
84
+ function modalNextImage(event) {
85
+ modalImageSwitch(1)
86
+ event.stopPropagation()
87
+ }
88
+
89
+ function modalPrevImage(event) {
90
+ modalImageSwitch(-1)
91
+ event.stopPropagation()
92
+ }
93
+
94
+ function modalKeyHandler(event) {
95
+ switch (event.key) {
96
+ case "ArrowLeft":
97
+ modalPrevImage(event)
98
+ break;
99
+ case "ArrowRight":
100
+ modalNextImage(event)
101
+ break;
102
+ case "Escape":
103
+ closeModal();
104
+ break;
105
+ }
106
+ }
107
+
108
+ function showGalleryImage() {
109
+ setTimeout(function() {
110
+ fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
111
+
112
+ if (fullImg_preview != null) {
113
+ fullImg_preview.forEach(function function_name(e) {
114
+ if (e.dataset.modded)
115
+ return;
116
+ e.dataset.modded = true;
117
+ if(e && e.parentElement.tagName == 'DIV'){
118
+ e.style.cursor='pointer'
119
+ e.addEventListener('click', function (evt) {
120
+ if(!opts.js_modal_lightbox) return;
121
+ modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
122
+ showModal(evt)
123
+ }, true);
124
+ }
125
+ });
126
+ }
127
+
128
+ }, 100);
129
+ }
130
+
131
+ function modalZoomSet(modalImage, enable) {
132
+ if (enable) {
133
+ modalImage.classList.add('modalImageFullscreen');
134
+ } else {
135
+ modalImage.classList.remove('modalImageFullscreen');
136
+ }
137
+ }
138
+
139
+ function modalZoomToggle(event) {
140
+ modalImage = gradioApp().getElementById("modalImage");
141
+ modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
142
+ event.stopPropagation()
143
+ }
144
+
145
+ function modalTileImageToggle(event) {
146
+ const modalImage = gradioApp().getElementById("modalImage");
147
+ const modal = gradioApp().getElementById("lightboxModal");
148
+ const isTiling = modalImage.style.display === 'none';
149
+ if (isTiling) {
150
+ modalImage.style.display = 'block';
151
+ modal.style.setProperty('background-image', 'none')
152
+ } else {
153
+ modalImage.style.display = 'none';
154
+ modal.style.setProperty('background-image', `url(${modalImage.src})`)
155
+ }
156
+
157
+ event.stopPropagation()
158
+ }
159
+
160
+ function galleryImageHandler(e) {
161
+ if (e && e.parentElement.tagName == 'BUTTON') {
162
+ e.onclick = showGalleryImage;
163
+ }
164
+ }
165
+
166
+ onUiUpdate(function() {
167
+ fullImg_preview = gradioApp().querySelectorAll('img.w-full')
168
+ if (fullImg_preview != null) {
169
+ fullImg_preview.forEach(galleryImageHandler);
170
+ }
171
+ updateOnBackgroundChange();
172
+ })
173
+
174
+ document.addEventListener("DOMContentLoaded", function() {
175
+ const modalFragment = document.createDocumentFragment();
176
+ const modal = document.createElement('div')
177
+ modal.onclick = closeModal;
178
+ modal.id = "lightboxModal";
179
+ modal.tabIndex = 0
180
+ modal.addEventListener('keydown', modalKeyHandler, true)
181
+
182
+ const modalControls = document.createElement('div')
183
+ modalControls.className = 'modalControls gradio-container';
184
+ modal.append(modalControls);
185
+
186
+ const modalZoom = document.createElement('span')
187
+ modalZoom.className = 'modalZoom cursor';
188
+ modalZoom.innerHTML = '&#10529;'
189
+ modalZoom.addEventListener('click', modalZoomToggle, true)
190
+ modalZoom.title = "Toggle zoomed view";
191
+ modalControls.appendChild(modalZoom)
192
+
193
+ const modalTileImage = document.createElement('span')
194
+ modalTileImage.className = 'modalTileImage cursor';
195
+ modalTileImage.innerHTML = '&#8862;'
196
+ modalTileImage.addEventListener('click', modalTileImageToggle, true)
197
+ modalTileImage.title = "Preview tiling";
198
+ modalControls.appendChild(modalTileImage)
199
+
200
+ const modalClose = document.createElement('span')
201
+ modalClose.className = 'modalClose cursor';
202
+ modalClose.innerHTML = '&times;'
203
+ modalClose.onclick = closeModal;
204
+ modalClose.title = "Close image viewer";
205
+ modalControls.appendChild(modalClose)
206
+
207
+ const modalImage = document.createElement('img')
208
+ modalImage.id = 'modalImage';
209
+ modalImage.onclick = closeModal;
210
+ modalImage.tabIndex = 0
211
+ modalImage.addEventListener('keydown', modalKeyHandler, true)
212
+ modal.appendChild(modalImage)
213
+
214
+ const modalPrev = document.createElement('a')
215
+ modalPrev.className = 'modalPrev';
216
+ modalPrev.innerHTML = '&#10094;'
217
+ modalPrev.tabIndex = 0
218
+ modalPrev.addEventListener('click', modalPrevImage, true);
219
+ modalPrev.addEventListener('keydown', modalKeyHandler, true)
220
+ modal.appendChild(modalPrev)
221
+
222
+ const modalNext = document.createElement('a')
223
+ modalNext.className = 'modalNext';
224
+ modalNext.innerHTML = '&#10095;'
225
+ modalNext.tabIndex = 0
226
+ modalNext.addEventListener('click', modalNextImage, true);
227
+ modalNext.addEventListener('keydown', modalKeyHandler, true)
228
+
229
+ modal.appendChild(modalNext)
230
+
231
+
232
+ gradioApp().getRootNode().appendChild(modal)
233
+
234
+ document.body.appendChild(modalFragment);
235
+
236
+ });
javascript/notification.js ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Monitors the gallery and sends a browser notification when the leading image is new.
2
+
3
+ let lastHeadImg = null;
4
+
5
+ notificationButton = null
6
+
7
+ onUiUpdate(function(){
8
+ if(notificationButton == null){
9
+ notificationButton = gradioApp().getElementById('request_notifications')
10
+
11
+ if(notificationButton != null){
12
+ notificationButton.addEventListener('click', function (evt) {
13
+ Notification.requestPermission();
14
+ },true);
15
+ }
16
+ }
17
+
18
+ const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden');
19
+
20
+ if (galleryPreviews == null) return;
21
+
22
+ const headImg = galleryPreviews[0]?.src;
23
+
24
+ if (headImg == null || headImg == lastHeadImg) return;
25
+
26
+ lastHeadImg = headImg;
27
+
28
+ // play notification sound if available
29
+ gradioApp().querySelector('#audio_notification audio')?.play();
30
+
31
+ if (document.hasFocus()) return;
32
+
33
+ // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
34
+ const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));
35
+
36
+ const notification = new Notification(
37
+ 'Stable Diffusion',
38
+ {
39
+ body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
40
+ icon: headImg,
41
+ image: headImg,
42
+ }
43
+ );
44
+
45
+ notification.onclick = function(_){
46
+ parent.focus();
47
+ this.close();
48
+ };
49
+ });
javascript/progressbar.js ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // code related to showing and updating progressbar shown as the image is being made
2
+ global_progressbars = {}
3
+
4
+ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
5
+ var progressbar = gradioApp().getElementById(id_progressbar)
6
+ var skip = id_skip ? gradioApp().getElementById(id_skip) : null
7
+ var interrupt = gradioApp().getElementById(id_interrupt)
8
+
9
+ if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
10
+ if(progressbar.innerText){
11
+ let newtitle = 'Stable Diffusion - ' + progressbar.innerText
12
+ if(document.title != newtitle){
13
+ document.title = newtitle;
14
+ }
15
+ }else{
16
+ let newtitle = 'Stable Diffusion'
17
+ if(document.title != newtitle){
18
+ document.title = newtitle;
19
+ }
20
+ }
21
+ }
22
+
23
+ if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
24
+ global_progressbars[id_progressbar] = progressbar
25
+
26
+ var mutationObserver = new MutationObserver(function(m){
27
+ preview = gradioApp().getElementById(id_preview)
28
+ gallery = gradioApp().getElementById(id_gallery)
29
+
30
+ if(preview != null && gallery != null){
31
+ preview.style.width = gallery.clientWidth + "px"
32
+ preview.style.height = gallery.clientHeight + "px"
33
+
34
+ var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
35
+ if(!progressDiv){
36
+ if (skip) {
37
+ skip.style.display = "none"
38
+ }
39
+ interrupt.style.display = "none"
40
+ }
41
+ }
42
+
43
+ window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
44
+ });
45
+ mutationObserver.observe( progressbar, { childList:true, subtree:true })
46
+ }
47
+ }
48
+
49
+ onUiUpdate(function(){
50
+ check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
51
+ check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
52
+ check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
53
+ })
54
+
55
+ function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
56
+ btn = gradioApp().getElementById(id_part+"_check_progress");
57
+ if(btn==null) return;
58
+
59
+ btn.click();
60
+ var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
61
+ var skip = id_skip ? gradioApp().getElementById(id_skip) : null
62
+ var interrupt = gradioApp().getElementById(id_interrupt)
63
+ if(progressDiv && interrupt){
64
+ if (skip) {
65
+ skip.style.display = "block"
66
+ }
67
+ interrupt.style.display = "block"
68
+ }
69
+ }
70
+
71
+ function requestProgress(id_part){
72
+ btn = gradioApp().getElementById(id_part+"_check_progress_initial");
73
+ if(btn==null) return;
74
+
75
+ btn.click();
76
+ }
javascript/textualInversion.js ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ function start_training_textual_inversion(){
4
+ requestProgress('ti')
5
+ gradioApp().querySelector('#ti_error').innerHTML=''
6
+
7
+ return args_to_array(arguments)
8
+ }
javascript/ui.js ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // various functions for interation with ui.py not large enough to warrant putting them in separate files
2
+
3
+ function selected_gallery_index(){
4
+ var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
5
+ var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
6
+
7
+ var result = -1
8
+ buttons.forEach(function(v, i){ if(v==button) { result = i } })
9
+
10
+ return result
11
+ }
12
+
13
+ function extract_image_from_gallery(gallery){
14
+ if(gallery.length == 1){
15
+ return gallery[0]
16
+ }
17
+
18
+ index = selected_gallery_index()
19
+
20
+ if (index < 0 || index >= gallery.length){
21
+ return [null]
22
+ }
23
+
24
+ return gallery[index];
25
+ }
26
+
27
+ function args_to_array(args){
28
+ res = []
29
+ for(var i=0;i<args.length;i++){
30
+ res.push(args[i])
31
+ }
32
+ return res
33
+ }
34
+
35
+ function switch_to_txt2img(){
36
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
37
+
38
+ return args_to_array(arguments);
39
+ }
40
+
41
+ function switch_to_img2img_img2img(){
42
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
43
+ gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
44
+
45
+ return args_to_array(arguments);
46
+ }
47
+
48
+ function switch_to_img2img_inpaint(){
49
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
50
+ gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
51
+
52
+ return args_to_array(arguments);
53
+ }
54
+
55
+ function switch_to_extras(){
56
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
57
+
58
+ return args_to_array(arguments);
59
+ }
60
+
61
+ function extract_image_from_gallery_txt2img(gallery){
62
+ switch_to_txt2img()
63
+ return extract_image_from_gallery(gallery);
64
+ }
65
+
66
+ function extract_image_from_gallery_img2img(gallery){
67
+ switch_to_img2img_img2img()
68
+ return extract_image_from_gallery(gallery);
69
+ }
70
+
71
+ function extract_image_from_gallery_inpaint(gallery){
72
+ switch_to_img2img_inpaint()
73
+ return extract_image_from_gallery(gallery);
74
+ }
75
+
76
+ function extract_image_from_gallery_extras(gallery){
77
+ switch_to_extras()
78
+ return extract_image_from_gallery(gallery);
79
+ }
80
+
81
+ function get_tab_index(tabId){
82
+ var res = 0
83
+
84
+ gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
85
+ if(button.className.indexOf('bg-white') != -1)
86
+ res = i
87
+ })
88
+
89
+ return res
90
+ }
91
+
92
+ function create_tab_index_args(tabId, args){
93
+ var res = []
94
+ for(var i=0; i<args.length; i++){
95
+ res.push(args[i])
96
+ }
97
+
98
+ res[0] = get_tab_index(tabId)
99
+
100
+ return res
101
+ }
102
+
103
+ function get_extras_tab_index(){
104
+ const [,,...args] = [...arguments]
105
+ return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
106
+ }
107
+
108
+ function create_submit_args(args){
109
+ res = []
110
+ for(var i=0;i<args.length;i++){
111
+ res.push(args[i])
112
+ }
113
+
114
+ // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
115
+ // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
116
+ // I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
117
+ // If gradio at some point stops sending outputs, this may break something
118
+ if(Array.isArray(res[res.length - 3])){
119
+ res[res.length - 3] = null
120
+ }
121
+
122
+ return res
123
+ }
124
+
125
+ function submit(){
126
+ requestProgress('txt2img')
127
+
128
+ return create_submit_args(arguments)
129
+ }
130
+
131
+ function submit_img2img(){
132
+ requestProgress('img2img')
133
+
134
+ res = create_submit_args(arguments)
135
+
136
+ res[0] = get_tab_index('mode_img2img')
137
+
138
+ return res
139
+ }
140
+
141
+
142
+ function ask_for_style_name(_, prompt_text, negative_prompt_text) {
143
+ name_ = prompt('Style name:')
144
+ return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text]
145
+ }
146
+
147
+
148
+
149
+ opts = {}
150
+ function apply_settings(jsdata){
151
+ console.log(jsdata)
152
+
153
+ opts = JSON.parse(jsdata)
154
+
155
+ return jsdata
156
+ }
157
+
158
+ onUiUpdate(function(){
159
+ if(Object.keys(opts).length != 0) return;
160
+
161
+ json_elem = gradioApp().getElementById('settings_json')
162
+ if(json_elem == null) return;
163
+
164
+ textarea = json_elem.querySelector('textarea')
165
+ jsdata = textarea.value
166
+ opts = JSON.parse(jsdata)
167
+
168
+
169
+ Object.defineProperty(textarea, 'value', {
170
+ set: function(newValue) {
171
+ var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
172
+ var oldValue = valueProp.get.call(textarea);
173
+ valueProp.set.call(textarea, newValue);
174
+
175
+ if (oldValue != newValue) {
176
+ opts = JSON.parse(textarea.value)
177
+ }
178
+ },
179
+ get: function() {
180
+ var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
181
+ return valueProp.get.call(textarea);
182
+ }
183
+ });
184
+
185
+ json_elem.parentElement.style.display="none"
186
+
187
+ if (!txt2img_textarea) {
188
+ txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
189
+ txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
190
+ txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
191
+ }
192
+ if (!img2img_textarea) {
193
+ img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
194
+ img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
195
+ img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
196
+ }
197
+ })
198
+
199
+ let txt2img_textarea, img2img_textarea = undefined;
200
+ let wait_time = 800
201
+ let token_timeout;
202
+
203
+ function update_txt2img_tokens(...args) {
204
+ update_token_counter("txt2img_token_button")
205
+ if (args.length == 2)
206
+ return args[0]
207
+ return args;
208
+ }
209
+
210
+ function update_img2img_tokens(...args) {
211
+ update_token_counter("img2img_token_button")
212
+ if (args.length == 2)
213
+ return args[0]
214
+ return args;
215
+ }
216
+
217
+ function update_token_counter(button_id) {
218
+ if (token_timeout)
219
+ clearTimeout(token_timeout);
220
+ token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
221
+ }
222
+
223
+ function submit_prompt(event, generate_button_id) {
224
+ if (event.altKey && event.keyCode === 13) {
225
+ event.preventDefault();
226
+ gradioApp().getElementById(generate_button_id).click();
227
+ return;
228
+ }
229
+ }
230
+
231
+ function restart_reload(){
232
+ document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
233
+ setTimeout(function(){location.reload()},2000)
234
+ }
launch.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this scripts installs necessary requirements and launches main program in webui.py
2
+ import subprocess
3
+ import os
4
+ import sys
5
+ import importlib.util
6
+ import shlex
7
+ import platform
8
+
9
+ dir_repos = "repositories"
10
+ python = sys.executable
11
+ git = os.environ.get('GIT', "git")
12
+
13
+
14
+ def extract_arg(args, name):
15
+ return [x for x in args if x != name], name in args
16
+
17
+
18
+ def run(command, desc=None, errdesc=None):
19
+ if desc is not None:
20
+ print(desc)
21
+
22
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
23
+
24
+ if result.returncode != 0:
25
+
26
+ message = f"""{errdesc or 'Error running command'}.
27
+ Command: {command}
28
+ Error code: {result.returncode}
29
+ stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
30
+ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
31
+ """
32
+ raise RuntimeError(message)
33
+
34
+ return result.stdout.decode(encoding="utf8", errors="ignore")
35
+
36
+
37
+ def check_run(command):
38
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
39
+ return result.returncode == 0
40
+
41
+
42
+ def is_installed(package):
43
+ try:
44
+ spec = importlib.util.find_spec(package)
45
+ except ModuleNotFoundError:
46
+ return False
47
+
48
+ return spec is not None
49
+
50
+
51
+ def repo_dir(name):
52
+ return os.path.join(dir_repos, name)
53
+
54
+
55
+ def run_python(code, desc=None, errdesc=None):
56
+ return run(f'"{python}" -c "{code}"', desc, errdesc)
57
+
58
+
59
+ def run_pip(args, desc=None):
60
+ return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
61
+
62
+
63
+ def check_run_python(code):
64
+ return check_run(f'"{python}" -c "{code}"')
65
+
66
+
67
+ def git_clone(url, dir, name, commithash=None):
68
+ # TODO clone into temporary dir and move if successful
69
+
70
+ if os.path.exists(dir):
71
+ if commithash is None:
72
+ return
73
+
74
+ current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
75
+ if current_hash == commithash:
76
+ return
77
+
78
+ run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
79
+ run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
80
+ return
81
+
82
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
83
+
84
+ if commithash is not None:
85
+ run(f'"{git}" -C {dir} checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
86
+
87
+
88
+ def prepare_enviroment():
89
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
90
+ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
91
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
92
+
93
+ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
94
+ clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
95
+
96
+ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
97
+ taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
98
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
99
+ codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
100
+ blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
101
+
102
+ args = shlex.split(commandline_args)
103
+
104
+ args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
105
+ xformers = '--xformers' in args
106
+ deepdanbooru = '--deepdanbooru' in args
107
+ ngrok = '--ngrok' in args
108
+
109
+ try:
110
+ commit = run(f"{git} rev-parse HEAD").strip()
111
+ except Exception:
112
+ commit = "<none>"
113
+
114
+ print(f"Python {sys.version}")
115
+ print(f"Commit hash: {commit}")
116
+
117
+ if not is_installed("torch") or not is_installed("torchvision"):
118
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
119
+
120
+ if not skip_torch_cuda_test:
121
+ run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
122
+
123
+ if not is_installed("gfpgan"):
124
+ run_pip(f"install {gfpgan_package}", "gfpgan")
125
+
126
+ if not is_installed("clip"):
127
+ run_pip(f"install {clip_package}", "clip")
128
+
129
+ if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
130
+ if platform.system() == "Windows":
131
+ run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
132
+ elif platform.system() == "Linux":
133
+ run_pip("install xformers", "xformers")
134
+
135
+ if not is_installed("deepdanbooru") and deepdanbooru:
136
+ run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
137
+
138
+ if not is_installed("pyngrok") and ngrok:
139
+ run_pip("install pyngrok", "ngrok")
140
+
141
+ os.makedirs(dir_repos, exist_ok=True)
142
+
143
+ git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
144
+ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
145
+ git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
146
+ git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
147
+ git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
148
+
149
+ if not is_installed("lpips"):
150
+ run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
151
+
152
+ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
153
+
154
+ sys.argv += args
155
+
156
+ if "--exit" in args:
157
+ print("Exiting because of --exit argument")
158
+ exit(0)
159
+
160
+
161
+ def start_webui():
162
+ print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
163
+ import webui
164
+ webui.webui()
165
+
166
+
167
+ if __name__ == "__main__":
168
+ prepare_enviroment()
169
+ start_webui()
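A minimal sketch of the install-gating pattern launch.py uses: only call pip when importlib cannot find the package. "lpips" is just an example package name taken from the script above; this is an illustration, not part of the commit.

    import importlib.util
    import subprocess
    import sys

    def is_installed(package):
        # mirror launch.py: a package counts as installed if find_spec locates it
        try:
            return importlib.util.find_spec(package) is not None
        except ModuleNotFoundError:
            return False

    if not is_installed("lpips"):
        # equivalent of run_pip(), without the output capture
        subprocess.run([sys.executable, "-m", "pip", "install", "lpips", "--prefer-binary"], check=True)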
modules/artists.py ADDED
@@ -0,0 +1,25 @@
1
+ import os.path
2
+ import csv
3
+ from collections import namedtuple
4
+
5
+ Artist = namedtuple("Artist", ['name', 'weight', 'category'])
6
+
7
+
8
+ class ArtistsDatabase:
9
+ def __init__(self, filename):
10
+ self.cats = set()
11
+ self.artists = []
12
+
13
+ if not os.path.exists(filename):
14
+ return
15
+
16
+ with open(filename, "r", newline='', encoding="utf8") as file:
17
+ reader = csv.DictReader(file)
18
+
19
+ for row in reader:
20
+ artist = Artist(row["artist"], float(row["score"]), row["category"])
21
+ self.artists.append(artist)
22
+ self.cats.add(artist.category)
23
+
24
+ def categories(self):
25
+ return sorted(self.cats)
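A quick sanity check for ArtistsDatabase, assuming the file is importable as modules.artists. The CSV columns (artist, score, category) come from the reader above; the sample row is made up for illustration.

    import csv, os, tempfile
    from modules.artists import ArtistsDatabase

    path = os.path.join(tempfile.mkdtemp(), "artists.csv")
    with open(path, "w", newline="", encoding="utf8") as f:
        writer = csv.DictWriter(f, fieldnames=["artist", "score", "category"])
        writer.writeheader()
        writer.writerow({"artist": "Example Artist", "score": "0.9", "category": "digital"})

    db = ArtistsDatabase(path)
    print(db.categories())  # -> ['digital']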
modules/bsrgan_model.py ADDED
@@ -0,0 +1,76 @@
1
+ import os.path
2
+ import sys
3
+ import traceback
4
+
5
+ import PIL.Image
6
+ import numpy as np
7
+ import torch
8
+ from basicsr.utils.download_util import load_file_from_url
9
+
10
+ import modules.upscaler
11
+ from modules import devices, modelloader
12
+ from modules.bsrgan_model_arch import RRDBNet
13
+
14
+
15
+ class UpscalerBSRGAN(modules.upscaler.Upscaler):
16
+ def __init__(self, dirname):
17
+ self.name = "BSRGAN"
18
+ self.model_name = "BSRGAN 4x"
19
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
20
+ self.user_path = dirname
21
+ super().__init__()
22
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
23
+ scalers = []
24
+ if len(model_paths) == 0:
25
+ scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
26
+ scalers.append(scaler_data)
27
+ for file in model_paths:
28
+ if "http" in file:
29
+ name = self.model_name
30
+ else:
31
+ name = modelloader.friendly_name(file)
32
+ try:
33
+ scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
34
+ scalers.append(scaler_data)
35
+ except Exception:
36
+ print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
37
+ print(traceback.format_exc(), file=sys.stderr)
38
+ self.scalers = scalers
39
+
40
+ def do_upscale(self, img: PIL.Image.Image, selected_file):
41
+ torch.cuda.empty_cache()
42
+ model = self.load_model(selected_file)
43
+ if model is None:
44
+ return img
45
+ model.to(devices.device_bsrgan)
46
+ torch.cuda.empty_cache()
47
+ img = np.array(img)
48
+ img = img[:, :, ::-1]
49
+ img = np.moveaxis(img, 2, 0) / 255
50
+ img = torch.from_numpy(img).float()
51
+ img = img.unsqueeze(0).to(devices.device_bsrgan)
52
+ with torch.no_grad():
53
+ output = model(img)
54
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
55
+ output = 255. * np.moveaxis(output, 0, 2)
56
+ output = output.astype(np.uint8)
57
+ output = output[:, :, ::-1]
58
+ torch.cuda.empty_cache()
59
+ return PIL.Image.fromarray(output, 'RGB')
60
+
61
+ def load_model(self, path: str):
62
+ if "http" in path:
63
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
64
+ progress=True)
65
+ else:
66
+ filename = path
67
+ if not os.path.exists(filename) or filename is None:
68
+ print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
69
+ return None
70
+ model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
71
+ model.load_state_dict(torch.load(filename), strict=True)
72
+ model.eval()
73
+ for k, v in model.named_parameters():
74
+ v.requires_grad = False
75
+ return model
76
+
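The pre/post-processing in do_upscale is a fixed round trip: RGB to BGR, HWC to normalized NCHW float, through the model, then back. A standalone sketch of that round trip, with an identity step standing in for the BSRGAN model:

    import numpy as np
    import torch
    import PIL.Image

    img = PIL.Image.new("RGB", (64, 64), color=(128, 64, 32))
    x = np.array(img)[:, :, ::-1]                                            # RGB -> BGR
    x = torch.from_numpy(np.moveaxis(x, 2, 0) / 255.).float().unsqueeze(0)   # NCHW in [0, 1]

    out = x  # a real BSRGAN model would produce a 4x-upscaled tensor here
    out = out.squeeze().clamp_(0, 1).numpy()
    out = (255. * np.moveaxis(out, 0, 2)).astype(np.uint8)[:, :, ::-1]       # back to RGB HWC
    result = PIL.Image.fromarray(out, "RGB")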
modules/bsrgan_model_arch.py ADDED
@@ -0,0 +1,102 @@
1
+ import functools
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+
7
+
8
+ def initialize_weights(net_l, scale=1):
9
+ if not isinstance(net_l, list):
10
+ net_l = [net_l]
11
+ for net in net_l:
12
+ for m in net.modules():
13
+ if isinstance(m, nn.Conv2d):
14
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
15
+ m.weight.data *= scale # for residual block
16
+ if m.bias is not None:
17
+ m.bias.data.zero_()
18
+ elif isinstance(m, nn.Linear):
19
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
20
+ m.weight.data *= scale
21
+ if m.bias is not None:
22
+ m.bias.data.zero_()
23
+ elif isinstance(m, nn.BatchNorm2d):
24
+ init.constant_(m.weight, 1)
25
+ init.constant_(m.bias.data, 0.0)
26
+
27
+
28
+ def make_layer(block, n_layers):
29
+ layers = []
30
+ for _ in range(n_layers):
31
+ layers.append(block())
32
+ return nn.Sequential(*layers)
33
+
34
+
35
+ class ResidualDenseBlock_5C(nn.Module):
36
+ def __init__(self, nf=64, gc=32, bias=True):
37
+ super(ResidualDenseBlock_5C, self).__init__()
38
+ # gc: growth channel, i.e. intermediate channels
39
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
40
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
41
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
42
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
43
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
44
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
45
+
46
+ # initialization
47
+ initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
48
+
49
+ def forward(self, x):
50
+ x1 = self.lrelu(self.conv1(x))
51
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
52
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
53
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
54
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
55
+ return x5 * 0.2 + x
56
+
57
+
58
+ class RRDB(nn.Module):
59
+ '''Residual in Residual Dense Block'''
60
+
61
+ def __init__(self, nf, gc=32):
62
+ super(RRDB, self).__init__()
63
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
64
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
65
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
66
+
67
+ def forward(self, x):
68
+ out = self.RDB1(x)
69
+ out = self.RDB2(out)
70
+ out = self.RDB3(out)
71
+ return out * 0.2 + x
72
+
73
+
74
+ class RRDBNet(nn.Module):
75
+ def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
76
+ super(RRDBNet, self).__init__()
77
+ RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
78
+ self.sf = sf
79
+
80
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
81
+ self.RRDB_trunk = make_layer(RRDB_block_f, nb)
82
+ self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
83
+ #### upsampling
84
+ self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
85
+ if self.sf==4:
86
+ self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
87
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
88
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
89
+
90
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
91
+
92
+ def forward(self, x):
93
+ fea = self.conv_first(x)
94
+ trunk = self.trunk_conv(self.RRDB_trunk(fea))
95
+ fea = fea + trunk
96
+
97
+ fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
98
+ if self.sf==4:
99
+ fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
100
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
101
+
102
+ return out
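A shape check for the network above, assuming the file is importable as modules.bsrgan_model_arch: with sf=4 each input dimension is doubled twice, so a 64x64 input comes out 256x256 (the output size depends only on sf, not on the block count nb).

    import torch
    from modules.bsrgan_model_arch import RRDBNet

    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=4, gc=32, sf=4).eval()  # small nb to keep it quick
    with torch.no_grad():
        y = net(torch.randn(1, 3, 64, 64))
    print(y.shape)  # torch.Size([1, 3, 256, 256])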
modules/codeformer/codeformer_arch.py ADDED
@@ -0,0 +1,278 @@
1
+ # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
2
+
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn, Tensor
7
+ import torch.nn.functional as F
8
+ from typing import Optional, List
9
+
10
+ from modules.codeformer.vqgan_arch import *
11
+ from basicsr.utils import get_root_logger
12
+ from basicsr.utils.registry import ARCH_REGISTRY
13
+
14
+ def calc_mean_std(feat, eps=1e-5):
15
+ """Calculate mean and std for adaptive_instance_normalization.
16
+
17
+ Args:
18
+ feat (Tensor): 4D tensor.
19
+ eps (float): A small value added to the variance to avoid
20
+ divide-by-zero. Default: 1e-5.
21
+ """
22
+ size = feat.size()
23
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
24
+ b, c = size[:2]
25
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
26
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
27
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
28
+ return feat_mean, feat_std
29
+
30
+
31
+ def adaptive_instance_normalization(content_feat, style_feat):
32
+ """Adaptive instance normalization.
33
+
34
+ Adjust the reference features to have a similar color and illumination
35
+ to those in the degraded features.
36
+
37
+ Args:
38
+ content_feat (Tensor): The reference feature.
39
+ style_feat (Tensor): The degraded features.
40
+ """
41
+ size = content_feat.size()
42
+ style_mean, style_std = calc_mean_std(style_feat)
43
+ content_mean, content_std = calc_mean_std(content_feat)
44
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
45
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
46
+
47
+
48
+ class PositionEmbeddingSine(nn.Module):
49
+ """
50
+ This is a more standard version of the position embedding, very similar to the one
51
+ used by the Attention is all you need paper, generalized to work on images.
52
+ """
53
+
54
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
55
+ super().__init__()
56
+ self.num_pos_feats = num_pos_feats
57
+ self.temperature = temperature
58
+ self.normalize = normalize
59
+ if scale is not None and normalize is False:
60
+ raise ValueError("normalize should be True if scale is passed")
61
+ if scale is None:
62
+ scale = 2 * math.pi
63
+ self.scale = scale
64
+
65
+ def forward(self, x, mask=None):
66
+ if mask is None:
67
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
68
+ not_mask = ~mask
69
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
70
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
71
+ if self.normalize:
72
+ eps = 1e-6
73
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
74
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
75
+
76
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
77
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
78
+
79
+ pos_x = x_embed[:, :, :, None] / dim_t
80
+ pos_y = y_embed[:, :, :, None] / dim_t
81
+ pos_x = torch.stack(
82
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
83
+ ).flatten(3)
84
+ pos_y = torch.stack(
85
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
86
+ ).flatten(3)
87
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
88
+ return pos
89
+
90
+ def _get_activation_fn(activation):
91
+ """Return an activation function given a string"""
92
+ if activation == "relu":
93
+ return F.relu
94
+ if activation == "gelu":
95
+ return F.gelu
96
+ if activation == "glu":
97
+ return F.glu
98
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
99
+
100
+
101
+ class TransformerSALayer(nn.Module):
102
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
103
+ super().__init__()
104
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
105
+ # Implementation of Feedforward model - MLP
106
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
107
+ self.dropout = nn.Dropout(dropout)
108
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
109
+
110
+ self.norm1 = nn.LayerNorm(embed_dim)
111
+ self.norm2 = nn.LayerNorm(embed_dim)
112
+ self.dropout1 = nn.Dropout(dropout)
113
+ self.dropout2 = nn.Dropout(dropout)
114
+
115
+ self.activation = _get_activation_fn(activation)
116
+
117
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
118
+ return tensor if pos is None else tensor + pos
119
+
120
+ def forward(self, tgt,
121
+ tgt_mask: Optional[Tensor] = None,
122
+ tgt_key_padding_mask: Optional[Tensor] = None,
123
+ query_pos: Optional[Tensor] = None):
124
+
125
+ # self attention
126
+ tgt2 = self.norm1(tgt)
127
+ q = k = self.with_pos_embed(tgt2, query_pos)
128
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
129
+ key_padding_mask=tgt_key_padding_mask)[0]
130
+ tgt = tgt + self.dropout1(tgt2)
131
+
132
+ # ffn
133
+ tgt2 = self.norm2(tgt)
134
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
135
+ tgt = tgt + self.dropout2(tgt2)
136
+ return tgt
137
+
138
+ class Fuse_sft_block(nn.Module):
139
+ def __init__(self, in_ch, out_ch):
140
+ super().__init__()
141
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
142
+
143
+ self.scale = nn.Sequential(
144
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
145
+ nn.LeakyReLU(0.2, True),
146
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
147
+
148
+ self.shift = nn.Sequential(
149
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
150
+ nn.LeakyReLU(0.2, True),
151
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
152
+
153
+ def forward(self, enc_feat, dec_feat, w=1):
154
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
155
+ scale = self.scale(enc_feat)
156
+ shift = self.shift(enc_feat)
157
+ residual = w * (dec_feat * scale + shift)
158
+ out = dec_feat + residual
159
+ return out
160
+
161
+
162
+ @ARCH_REGISTRY.register()
163
+ class CodeFormer(VQAutoEncoder):
164
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
165
+ codebook_size=1024, latent_size=256,
166
+ connect_list=['32', '64', '128', '256'],
167
+ fix_modules=['quantize','generator']):
168
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
169
+
170
+ if fix_modules is not None:
171
+ for module in fix_modules:
172
+ for param in getattr(self, module).parameters():
173
+ param.requires_grad = False
174
+
175
+ self.connect_list = connect_list
176
+ self.n_layers = n_layers
177
+ self.dim_embd = dim_embd
178
+ self.dim_mlp = dim_embd*2
179
+
180
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
181
+ self.feat_emb = nn.Linear(256, self.dim_embd)
182
+
183
+ # transformer
184
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
185
+ for _ in range(self.n_layers)])
186
+
187
+ # logits_predict head
188
+ self.idx_pred_layer = nn.Sequential(
189
+ nn.LayerNorm(dim_embd),
190
+ nn.Linear(dim_embd, codebook_size, bias=False))
191
+
192
+ self.channels = {
193
+ '16': 512,
194
+ '32': 256,
195
+ '64': 256,
196
+ '128': 128,
197
+ '256': 128,
198
+ '512': 64,
199
+ }
200
+
201
+ # after second residual block for > 16, before attn layer for ==16
202
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
203
+ # after first residual block for > 16, before attn layer for ==16
204
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
205
+
206
+ # fuse_convs_dict
207
+ self.fuse_convs_dict = nn.ModuleDict()
208
+ for f_size in self.connect_list:
209
+ in_ch = self.channels[f_size]
210
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
211
+
212
+ def _init_weights(self, module):
213
+ if isinstance(module, (nn.Linear, nn.Embedding)):
214
+ module.weight.data.normal_(mean=0.0, std=0.02)
215
+ if isinstance(module, nn.Linear) and module.bias is not None:
216
+ module.bias.data.zero_()
217
+ elif isinstance(module, nn.LayerNorm):
218
+ module.bias.data.zero_()
219
+ module.weight.data.fill_(1.0)
220
+
221
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
222
+ # ################### Encoder #####################
223
+ enc_feat_dict = {}
224
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
225
+ for i, block in enumerate(self.encoder.blocks):
226
+ x = block(x)
227
+ if i in out_list:
228
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
229
+
230
+ lq_feat = x
231
+ # ################# Transformer ###################
232
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
233
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
234
+ # BCHW -> BC(HW) -> (HW)BC
235
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
236
+ query_emb = feat_emb
237
+ # Transformer encoder
238
+ for layer in self.ft_layers:
239
+ query_emb = layer(query_emb, query_pos=pos_emb)
240
+
241
+ # output logits
242
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
243
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
244
+
245
+ if code_only: # for training stage II
246
+ # logits doesn't need softmax before cross_entropy loss
247
+ return logits, lq_feat
248
+
249
+ # ################# Quantization ###################
250
+ # if self.training:
251
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
252
+ # # b(hw)c -> bc(hw) -> bchw
253
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
254
+ # ------------
255
+ soft_one_hot = F.softmax(logits, dim=2)
256
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
257
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
258
+ # preserve gradients
259
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
260
+
261
+ if detach_16:
262
+ quant_feat = quant_feat.detach() # for training stage III
263
+ if adain:
264
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
265
+
266
+ # ################## Generator ####################
267
+ x = quant_feat
268
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
269
+
270
+ for i, block in enumerate(self.generator.blocks):
271
+ x = block(x)
272
+ if i in fuse_list: # fuse after i-th block
273
+ f_size = str(x.shape[-1])
274
+ if w>0:
275
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
276
+ out = x
277
+ # logits doesn't need softmax before cross_entropy loss
278
+ return out, logits, lq_feat
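A quick check of adaptive_instance_normalization defined above: the result keeps the content tensor's spatial structure but adopts the style tensor's per-channel mean and std. Assumes the file is importable as modules.codeformer.codeformer_arch (which in turn needs basicsr installed).

    import torch
    from modules.codeformer.codeformer_arch import adaptive_instance_normalization, calc_mean_std

    content = torch.randn(1, 8, 16, 16) * 3 + 5
    style = torch.randn(1, 8, 16, 16)

    out = adaptive_instance_normalization(content, style)
    out_mean, _ = calc_mean_std(out)
    style_mean, _ = calc_mean_std(style)
    print(torch.allclose(out_mean, style_mean, atol=1e-4))  # True: the style statistics carry over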
modules/codeformer/vqgan_arch.py ADDED
@@ -0,0 +1,437 @@
1
+ # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
2
+
3
+ '''
4
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
5
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
6
+
7
+ '''
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import copy
13
+ from basicsr.utils import get_root_logger
14
+ from basicsr.utils.registry import ARCH_REGISTRY
15
+
16
+ def normalize(in_channels):
17
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
18
+
19
+
20
+ @torch.jit.script
21
+ def swish(x):
22
+ return x*torch.sigmoid(x)
23
+
24
+
25
+ # Define VQVAE classes
26
+ class VectorQuantizer(nn.Module):
27
+ def __init__(self, codebook_size, emb_dim, beta):
28
+ super(VectorQuantizer, self).__init__()
29
+ self.codebook_size = codebook_size # number of embeddings
30
+ self.emb_dim = emb_dim # dimension of embedding
31
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
32
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
33
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
34
+
35
+ def forward(self, z):
36
+ # reshape z -> (batch, height, width, channel) and flatten
37
+ z = z.permute(0, 2, 3, 1).contiguous()
38
+ z_flattened = z.view(-1, self.emb_dim)
39
+
40
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
41
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
42
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
43
+
44
+ mean_distance = torch.mean(d)
45
+ # find closest encodings
46
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
47
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
48
+ # [0-1], higher score, higher confidence
49
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
50
+
51
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
52
+ min_encodings.scatter_(1, min_encoding_indices, 1)
53
+
54
+ # get quantized latent vectors
55
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
56
+ # compute loss for embedding
57
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
58
+ # preserve gradients
59
+ z_q = z + (z_q - z).detach()
60
+
61
+ # perplexity
62
+ e_mean = torch.mean(min_encodings, dim=0)
63
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
64
+ # reshape back to match original input shape
65
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
66
+
67
+ return z_q, loss, {
68
+ "perplexity": perplexity,
69
+ "min_encodings": min_encodings,
70
+ "min_encoding_indices": min_encoding_indices,
71
+ "min_encoding_scores": min_encoding_scores,
72
+ "mean_distance": mean_distance
73
+ }
74
+
75
+ def get_codebook_feat(self, indices, shape):
76
+ # input indices: batch*token_num -> (batch*token_num)*1
77
+ # shape: batch, height, width, channel
78
+ indices = indices.view(-1,1)
79
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
80
+ min_encodings.scatter_(1, indices, 1)
81
+ # get quantized latent vectors
82
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
83
+
84
+ if shape is not None: # reshape back to match original input shape
85
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
86
+
87
+ return z_q
88
+
89
+
90
+ class GumbelQuantizer(nn.Module):
91
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
92
+ super().__init__()
93
+ self.codebook_size = codebook_size # number of embeddings
94
+ self.emb_dim = emb_dim # dimension of embedding
95
+ self.straight_through = straight_through
96
+ self.temperature = temp_init
97
+ self.kl_weight = kl_weight
98
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
99
+ self.embed = nn.Embedding(codebook_size, emb_dim)
100
+
101
+ def forward(self, z):
102
+ hard = self.straight_through if self.training else True
103
+
104
+ logits = self.proj(z)
105
+
106
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
107
+
108
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
109
+
110
+ # + kl divergence to the prior loss
111
+ qy = F.softmax(logits, dim=1)
112
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
113
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
114
+
115
+ return z_q, diff, {
116
+ "min_encoding_indices": min_encoding_indices
117
+ }
118
+
119
+
120
+ class Downsample(nn.Module):
121
+ def __init__(self, in_channels):
122
+ super().__init__()
123
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
124
+
125
+ def forward(self, x):
126
+ pad = (0, 1, 0, 1)
127
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
128
+ x = self.conv(x)
129
+ return x
130
+
131
+
132
+ class Upsample(nn.Module):
133
+ def __init__(self, in_channels):
134
+ super().__init__()
135
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
136
+
137
+ def forward(self, x):
138
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
139
+ x = self.conv(x)
140
+
141
+ return x
142
+
143
+
144
+ class ResBlock(nn.Module):
145
+ def __init__(self, in_channels, out_channels=None):
146
+ super(ResBlock, self).__init__()
147
+ self.in_channels = in_channels
148
+ self.out_channels = in_channels if out_channels is None else out_channels
149
+ self.norm1 = normalize(in_channels)
150
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
151
+ self.norm2 = normalize(out_channels)
152
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
153
+ if self.in_channels != self.out_channels:
154
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
155
+
156
+ def forward(self, x_in):
157
+ x = x_in
158
+ x = self.norm1(x)
159
+ x = swish(x)
160
+ x = self.conv1(x)
161
+ x = self.norm2(x)
162
+ x = swish(x)
163
+ x = self.conv2(x)
164
+ if self.in_channels != self.out_channels:
165
+ x_in = self.conv_out(x_in)
166
+
167
+ return x + x_in
168
+
169
+
170
+ class AttnBlock(nn.Module):
171
+ def __init__(self, in_channels):
172
+ super().__init__()
173
+ self.in_channels = in_channels
174
+
175
+ self.norm = normalize(in_channels)
176
+ self.q = torch.nn.Conv2d(
177
+ in_channels,
178
+ in_channels,
179
+ kernel_size=1,
180
+ stride=1,
181
+ padding=0
182
+ )
183
+ self.k = torch.nn.Conv2d(
184
+ in_channels,
185
+ in_channels,
186
+ kernel_size=1,
187
+ stride=1,
188
+ padding=0
189
+ )
190
+ self.v = torch.nn.Conv2d(
191
+ in_channels,
192
+ in_channels,
193
+ kernel_size=1,
194
+ stride=1,
195
+ padding=0
196
+ )
197
+ self.proj_out = torch.nn.Conv2d(
198
+ in_channels,
199
+ in_channels,
200
+ kernel_size=1,
201
+ stride=1,
202
+ padding=0
203
+ )
204
+
205
+ def forward(self, x):
206
+ h_ = x
207
+ h_ = self.norm(h_)
208
+ q = self.q(h_)
209
+ k = self.k(h_)
210
+ v = self.v(h_)
211
+
212
+ # compute attention
213
+ b, c, h, w = q.shape
214
+ q = q.reshape(b, c, h*w)
215
+ q = q.permute(0, 2, 1)
216
+ k = k.reshape(b, c, h*w)
217
+ w_ = torch.bmm(q, k)
218
+ w_ = w_ * (int(c)**(-0.5))
219
+ w_ = F.softmax(w_, dim=2)
220
+
221
+ # attend to values
222
+ v = v.reshape(b, c, h*w)
223
+ w_ = w_.permute(0, 2, 1)
224
+ h_ = torch.bmm(v, w_)
225
+ h_ = h_.reshape(b, c, h, w)
226
+
227
+ h_ = self.proj_out(h_)
228
+
229
+ return x+h_
230
+
231
+
232
+ class Encoder(nn.Module):
233
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
234
+ super().__init__()
235
+ self.nf = nf
236
+ self.num_resolutions = len(ch_mult)
237
+ self.num_res_blocks = num_res_blocks
238
+ self.resolution = resolution
239
+ self.attn_resolutions = attn_resolutions
240
+
241
+ curr_res = self.resolution
242
+ in_ch_mult = (1,)+tuple(ch_mult)
243
+
244
+ blocks = []
245
+ # initial convolution
246
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
247
+
248
+ # residual and downsampling blocks, with attention on smaller res (16x16)
249
+ for i in range(self.num_resolutions):
250
+ block_in_ch = nf * in_ch_mult[i]
251
+ block_out_ch = nf * ch_mult[i]
252
+ for _ in range(self.num_res_blocks):
253
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
254
+ block_in_ch = block_out_ch
255
+ if curr_res in attn_resolutions:
256
+ blocks.append(AttnBlock(block_in_ch))
257
+
258
+ if i != self.num_resolutions - 1:
259
+ blocks.append(Downsample(block_in_ch))
260
+ curr_res = curr_res // 2
261
+
262
+ # non-local attention block
263
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
264
+ blocks.append(AttnBlock(block_in_ch))
265
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
266
+
267
+ # normalise and convert to latent size
268
+ blocks.append(normalize(block_in_ch))
269
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
270
+ self.blocks = nn.ModuleList(blocks)
271
+
272
+ def forward(self, x):
273
+ for block in self.blocks:
274
+ x = block(x)
275
+
276
+ return x
277
+
278
+
279
+ class Generator(nn.Module):
280
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
281
+ super().__init__()
282
+ self.nf = nf
283
+ self.ch_mult = ch_mult
284
+ self.num_resolutions = len(self.ch_mult)
285
+ self.num_res_blocks = res_blocks
286
+ self.resolution = img_size
287
+ self.attn_resolutions = attn_resolutions
288
+ self.in_channels = emb_dim
289
+ self.out_channels = 3
290
+ block_in_ch = self.nf * self.ch_mult[-1]
291
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
292
+
293
+ blocks = []
294
+ # initial conv
295
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
296
+
297
+ # non-local attention block
298
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
299
+ blocks.append(AttnBlock(block_in_ch))
300
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
301
+
302
+ for i in reversed(range(self.num_resolutions)):
303
+ block_out_ch = self.nf * self.ch_mult[i]
304
+
305
+ for _ in range(self.num_res_blocks):
306
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
307
+ block_in_ch = block_out_ch
308
+
309
+ if curr_res in self.attn_resolutions:
310
+ blocks.append(AttnBlock(block_in_ch))
311
+
312
+ if i != 0:
313
+ blocks.append(Upsample(block_in_ch))
314
+ curr_res = curr_res * 2
315
+
316
+ blocks.append(normalize(block_in_ch))
317
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
318
+
319
+ self.blocks = nn.ModuleList(blocks)
320
+
321
+
322
+ def forward(self, x):
323
+ for block in self.blocks:
324
+ x = block(x)
325
+
326
+ return x
327
+
328
+
329
+ @ARCH_REGISTRY.register()
330
+ class VQAutoEncoder(nn.Module):
331
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
332
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
333
+ super().__init__()
334
+ logger = get_root_logger()
335
+ self.in_channels = 3
336
+ self.nf = nf
337
+ self.n_blocks = res_blocks
338
+ self.codebook_size = codebook_size
339
+ self.embed_dim = emb_dim
340
+ self.ch_mult = ch_mult
341
+ self.resolution = img_size
342
+ self.attn_resolutions = attn_resolutions
343
+ self.quantizer_type = quantizer
344
+ self.encoder = Encoder(
345
+ self.in_channels,
346
+ self.nf,
347
+ self.embed_dim,
348
+ self.ch_mult,
349
+ self.n_blocks,
350
+ self.resolution,
351
+ self.attn_resolutions
352
+ )
353
+ if self.quantizer_type == "nearest":
354
+ self.beta = beta #0.25
355
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
356
+ elif self.quantizer_type == "gumbel":
357
+ self.gumbel_num_hiddens = emb_dim
358
+ self.straight_through = gumbel_straight_through
359
+ self.kl_weight = gumbel_kl_weight
360
+ self.quantize = GumbelQuantizer(
361
+ self.codebook_size,
362
+ self.embed_dim,
363
+ self.gumbel_num_hiddens,
364
+ self.straight_through,
365
+ self.kl_weight
366
+ )
367
+ self.generator = Generator(
368
+ self.nf,
369
+ self.embed_dim,
370
+ self.ch_mult,
371
+ self.n_blocks,
372
+ self.resolution,
373
+ self.attn_resolutions
374
+ )
375
+
376
+ if model_path is not None:
377
+ chkpt = torch.load(model_path, map_location='cpu')
378
+ if 'params_ema' in chkpt:
379
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
380
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
381
+ elif 'params' in chkpt:
382
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
383
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
384
+ else:
385
+ raise ValueError(f'Wrong params!')
386
+
387
+
388
+ def forward(self, x):
389
+ x = self.encoder(x)
390
+ quant, codebook_loss, quant_stats = self.quantize(x)
391
+ x = self.generator(quant)
392
+ return x, codebook_loss, quant_stats
393
+
394
+
395
+
396
+ # patch based discriminator
397
+ @ARCH_REGISTRY.register()
398
+ class VQGANDiscriminator(nn.Module):
399
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
400
+ super().__init__()
401
+
402
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
403
+ ndf_mult = 1
404
+ ndf_mult_prev = 1
405
+ for n in range(1, n_layers): # gradually increase the number of filters
406
+ ndf_mult_prev = ndf_mult
407
+ ndf_mult = min(2 ** n, 8)
408
+ layers += [
409
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
410
+ nn.BatchNorm2d(ndf * ndf_mult),
411
+ nn.LeakyReLU(0.2, True)
412
+ ]
413
+
414
+ ndf_mult_prev = ndf_mult
415
+ ndf_mult = min(2 ** n_layers, 8)
416
+
417
+ layers += [
418
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
419
+ nn.BatchNorm2d(ndf * ndf_mult),
420
+ nn.LeakyReLU(0.2, True)
421
+ ]
422
+
423
+ layers += [
424
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
425
+ self.main = nn.Sequential(*layers)
426
+
427
+ if model_path is not None:
428
+ chkpt = torch.load(model_path, map_location='cpu')
429
+ if 'params_d' in chkpt:
430
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
431
+ elif 'params' in chkpt:
432
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
433
+ else:
434
+ raise ValueError(f'Wrong params!')
435
+
436
+ def forward(self, x):
437
+ return self.main(x)
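The nearest-code lookup in VectorQuantizer.forward relies on the identity ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, which turns all pairwise distances into a single matmul. A standalone check of that identity with plain torch (random stand-ins for the latents and codebook):

    import torch

    z = torch.randn(10, 256)           # flattened latents
    codebook = torch.randn(1024, 256)  # embedding weights

    d = (z ** 2).sum(dim=1, keepdim=True) + (codebook ** 2).sum(1) - 2 * z @ codebook.t()
    d_ref = torch.cdist(z, codebook) ** 2
    print(torch.allclose(d, d_ref, atol=1e-2))  # True up to floating-point error
    print(d.argmin(dim=1)[:5])                  # indices of the nearest codebook entries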
modules/codeformer_model.py ADDED
@@ -0,0 +1,140 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import cv2
6
+ import torch
7
+
8
+ import modules.face_restoration
9
+ import modules.shared
10
+ from modules import shared, devices, modelloader
11
+ from modules.paths import script_path, models_path
12
+
13
+ # codeformer people made a choice to include a modified basicsr library in their project, which makes
14
+ # it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
15
+ # I am making a choice to include some files from codeformer to work around this issue.
16
+ model_dir = "Codeformer"
17
+ model_path = os.path.join(models_path, model_dir)
18
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
19
+
20
+ have_codeformer = False
21
+ codeformer = None
22
+
23
+
24
+ def setup_model(dirname):
25
+ global model_path
26
+ if not os.path.exists(model_path):
27
+ os.makedirs(model_path)
28
+
29
+ path = modules.paths.paths.get("CodeFormer", None)
30
+ if path is None:
31
+ return
32
+
33
+ try:
34
+ from torchvision.transforms.functional import normalize
35
+ from modules.codeformer.codeformer_arch import CodeFormer
36
+ from basicsr.utils.download_util import load_file_from_url
37
+ from basicsr.utils import imwrite, img2tensor, tensor2img
38
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
39
+ from modules.shared import cmd_opts
40
+
41
+ net_class = CodeFormer
42
+
43
+ class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
44
+ def name(self):
45
+ return "CodeFormer"
46
+
47
+ def __init__(self, dirname):
48
+ self.net = None
49
+ self.face_helper = None
50
+ self.cmd_dir = dirname
51
+
52
+ def create_models(self):
53
+
54
+ if self.net is not None and self.face_helper is not None:
55
+ self.net.to(devices.device_codeformer)
56
+ return self.net, self.face_helper
57
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
58
+ if len(model_paths) != 0:
59
+ ckpt_path = model_paths[0]
60
+ else:
61
+ print("Unable to load codeformer model.")
62
+ return None, None
63
+ net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
64
+ checkpoint = torch.load(ckpt_path)['params_ema']
65
+ net.load_state_dict(checkpoint)
66
+ net.eval()
67
+
68
+ face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
69
+
70
+ self.net = net
71
+ self.face_helper = face_helper
72
+
73
+ return net, face_helper
74
+
75
+ def send_model_to(self, device):
76
+ self.net.to(device)
77
+ self.face_helper.face_det.to(device)
78
+ self.face_helper.face_parse.to(device)
79
+
80
+ def restore(self, np_image, w=None):
81
+ np_image = np_image[:, :, ::-1]
82
+
83
+ original_resolution = np_image.shape[0:2]
84
+
85
+ self.create_models()
86
+ if self.net is None or self.face_helper is None:
87
+ return np_image
88
+
89
+ self.send_model_to(devices.device_codeformer)
90
+
91
+ self.face_helper.clean_all()
92
+ self.face_helper.read_image(np_image)
93
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
94
+ self.face_helper.align_warp_face()
95
+
96
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
97
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
98
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
99
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
100
+
101
+ try:
102
+ with torch.no_grad():
103
+ output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
104
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
105
+ del output
106
+ torch.cuda.empty_cache()
107
+ except Exception as error:
108
+ print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
109
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
110
+
111
+ restored_face = restored_face.astype('uint8')
112
+ self.face_helper.add_restored_face(restored_face)
113
+
114
+ self.face_helper.get_inverse_affine(None)
115
+
116
+ restored_img = self.face_helper.paste_faces_to_input_image()
117
+ restored_img = restored_img[:, :, ::-1]
118
+
119
+ if original_resolution != restored_img.shape[0:2]:
120
+ restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
121
+
122
+ self.face_helper.clean_all()
123
+
124
+ if shared.opts.face_restoration_unload:
125
+ self.send_model_to(devices.cpu)
126
+
127
+ return restored_img
128
+
129
+ global have_codeformer
130
+ have_codeformer = True
131
+
132
+ global codeformer
133
+ codeformer = FaceRestorerCodeFormer(dirname)
134
+ shared.face_restorers.append(codeformer)
135
+
136
+ except Exception:
137
+ print("Error setting up CodeFormer:", file=sys.stderr)
138
+ print(traceback.format_exc(), file=sys.stderr)
139
+
140
+ # sys.path = stored_sys_path
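A hypothetical usage sketch of the restorer above, assuming a fully set-up webui environment (facelib and basicsr importable, the CodeFormer repo registered in modules.paths, model weights downloadable); the directory name and the blank image are placeholders.

    import numpy as np
    import modules.codeformer_model as codeformer_model

    codeformer_model.setup_model("models/Codeformer")    # illustrative path
    if codeformer_model.have_codeformer:
        image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in RGB image
        restored = codeformer_model.codeformer.restore(image, w=0.5)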
modules/deepbooru.py ADDED
@@ -0,0 +1,173 @@
1
+ import os.path
2
+ from concurrent.futures import ProcessPoolExecutor
3
+ import multiprocessing
4
+ import time
5
+ import re
6
+
7
+ re_special = re.compile(r'([\\()])')
8
+
9
+ def get_deepbooru_tags(pil_image):
10
+ """
11
+ This method is for running only one image at a time for simple use. Used by the img2img interrogate.
12
+ """
13
+ from modules import shared # prevents circular reference
14
+
15
+ try:
16
+ create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
17
+ return get_tags_from_process(pil_image)
18
+ finally:
19
+ release_process()
20
+
21
+
22
+ OPT_INCLUDE_RANKS = "include_ranks"
23
+ def create_deepbooru_opts():
24
+ from modules import shared
25
+
26
+ return {
27
+ "use_spaces": shared.opts.deepbooru_use_spaces,
28
+ "use_escape": shared.opts.deepbooru_escape,
29
+ "alpha_sort": shared.opts.deepbooru_sort_alpha,
30
+ OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
31
+ }
32
+
33
+
34
+ def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
35
+ model, tags = get_deepbooru_tags_model()
36
+ while True: # while process is running, keep monitoring queue for new image
37
+ pil_image = queue.get()
38
+ if pil_image == "QUIT":
39
+ break
40
+ else:
41
+ deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
42
+
43
+
44
+ def create_deepbooru_process(threshold, deepbooru_opts):
45
+ """
46
+ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
47
+ to be processed in a row without reloading the model or creating a new process. To return the data, a shared
48
+ dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
49
+ to the dictionary and the method adding the image to the queue should wait for this value to be updated with
50
+ the tags.
51
+ """
52
+ from modules import shared # prevents circular reference
53
+ shared.deepbooru_process_manager = multiprocessing.Manager()
54
+ shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
55
+ shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
56
+ shared.deepbooru_process_return["value"] = -1
57
+ shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
58
+ shared.deepbooru_process.start()
59
+
60
+
61
+ def get_tags_from_process(image):
62
+ from modules import shared
63
+
64
+ shared.deepbooru_process_return["value"] = -1
65
+ shared.deepbooru_process_queue.put(image)
66
+ while shared.deepbooru_process_return["value"] == -1:
67
+ time.sleep(0.2)
68
+ caption = shared.deepbooru_process_return["value"]
69
+ shared.deepbooru_process_return["value"] = -1
70
+
71
+ return caption
72
+
73
+
74
+ def release_process():
75
+ """
76
+ Stops the deepbooru process to return used memory
77
+ """
78
+ from modules import shared # prevents circular reference
79
+ shared.deepbooru_process_queue.put("QUIT")
80
+ shared.deepbooru_process.join()
81
+ shared.deepbooru_process_queue = None
82
+ shared.deepbooru_process = None
83
+ shared.deepbooru_process_return = None
84
+ shared.deepbooru_process_manager = None
85
+
86
+ def get_deepbooru_tags_model():
87
+ import deepdanbooru as dd
88
+ import tensorflow as tf
89
+ import numpy as np
90
+ this_folder = os.path.dirname(__file__)
91
+ model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
92
+ if not os.path.exists(os.path.join(model_path, 'project.json')):
93
+ # there is no point importing these every time
94
+ import zipfile
95
+ from basicsr.utils.download_util import load_file_from_url
96
+ load_file_from_url(
97
+ r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
98
+ model_path)
99
+ with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
100
+ zip_ref.extractall(model_path)
101
+ os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
102
+
103
+ tags = dd.project.load_tags_from_project(model_path)
104
+ model = dd.project.load_model_from_project(
105
+ model_path, compile_model=True
106
+ )
107
+ return model, tags
108
+
109
+
110
+ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
111
+ import deepdanbooru as dd
112
+ import tensorflow as tf
113
+ import numpy as np
114
+
115
+ alpha_sort = deepbooru_opts['alpha_sort']
116
+ use_spaces = deepbooru_opts['use_spaces']
117
+ use_escape = deepbooru_opts['use_escape']
118
+ include_ranks = deepbooru_opts['include_ranks']
119
+
120
+ width = model.input_shape[2]
121
+ height = model.input_shape[1]
122
+ image = np.array(pil_image)
123
+ image = tf.image.resize(
124
+ image,
125
+ size=(height, width),
126
+ method=tf.image.ResizeMethod.AREA,
127
+ preserve_aspect_ratio=True,
128
+ )
129
+ image = image.numpy() # EagerTensor to np.array
130
+ image = dd.image.transform_and_pad_image(image, width, height)
131
+ image = image / 255.0
132
+ image_shape = image.shape
133
+ image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
134
+
135
+ y = model.predict(image)[0]
136
+
137
+ result_dict = {}
138
+
139
+ for i, tag in enumerate(tags):
140
+ result_dict[tag] = y[i]
141
+
142
+ unsorted_tags_in_theshold = []
143
+ result_tags_print = []
144
+ for tag in tags:
145
+ if result_dict[tag] >= threshold:
146
+ if tag.startswith("rating:"):
147
+ continue
148
+ unsorted_tags_in_theshold.append((result_dict[tag], tag))
149
+ result_tags_print.append(f'{result_dict[tag]} {tag}')
150
+
151
+ # sort tags
152
+ result_tags_out = []
153
+ sort_ndx = 0
154
+ if alpha_sort:
155
+ sort_ndx = 1
156
+
157
+ # sort in reverse by likelihood, or ascending alphabetically when alpha_sort is set, and format tag text as requested
158
+ unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
159
+ for weight, tag in unsorted_tags_in_theshold:
160
+ # note: tag_outformat will still have a colon if include_ranks is True
161
+ tag_outformat = tag.replace(':', ' ')
162
+ if use_spaces:
163
+ tag_outformat = tag_outformat.replace('_', ' ')
164
+ if use_escape:
165
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
166
+ if include_ranks:
167
+ tag_outformat = f"({tag_outformat}:{weight:.3f})"
168
+
169
+ result_tags_out.append(tag_outformat)
170
+
171
+ print('\n'.join(sorted(result_tags_print, reverse=True)))
172
+
173
+ return ', '.join(result_tags_out)
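A sketch of how the queue-based worker above is meant to be driven for a batch, assuming the webui environment (deepdanbooru and tensorflow installed); the threshold, option values, and file names are placeholders. For a single image, get_deepbooru_tags() wraps exactly this create/submit/release sequence.

    from PIL import Image
    from modules import deepbooru

    # options passed by hand instead of reading shared.opts; keys mirror create_deepbooru_opts()
    opts = {"use_spaces": False, "use_escape": True, "alpha_sort": False, "include_ranks": False}
    deepbooru.create_deepbooru_process(0.5, opts)
    try:
        for path in ["a.png", "b.png"]:  # illustrative file names
            print(path, deepbooru.get_tags_from_process(Image.open(path)))
    finally:
        deepbooru.release_process()      # shut the worker down and free its memory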
modules/devices.py ADDED
@@ -0,0 +1,72 @@
1
+ import contextlib
2
+
3
+ import torch
4
+
5
+ from modules import errors
6
+
7
+ # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
8
+ has_mps = getattr(torch, 'has_mps', False)
9
+
10
+ cpu = torch.device("cpu")
11
+
12
+
13
+ def get_optimal_device():
14
+ if torch.cuda.is_available():
15
+ return torch.device("cuda")
16
+
17
+ if has_mps:
18
+ return torch.device("mps")
19
+
20
+ return cpu
21
+
22
+
23
+ def torch_gc():
24
+ if torch.cuda.is_available():
25
+ torch.cuda.empty_cache()
26
+ torch.cuda.ipc_collect()
27
+
28
+
29
+ def enable_tf32():
30
+ if torch.cuda.is_available():
31
+ torch.backends.cuda.matmul.allow_tf32 = True
32
+ torch.backends.cudnn.allow_tf32 = True
33
+
34
+
35
+ errors.run(enable_tf32, "Enabling TF32")
36
+
37
+ device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
38
+ dtype = torch.float16
39
+ dtype_vae = torch.float16
40
+
41
+ def randn(seed, shape):
42
+ # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
43
+ if device.type == 'mps':
44
+ generator = torch.Generator(device=cpu)
45
+ generator.manual_seed(seed)
46
+ noise = torch.randn(shape, generator=generator, device=cpu).to(device)
47
+ return noise
48
+
49
+ torch.manual_seed(seed)
50
+ return torch.randn(shape, device=device)
51
+
52
+
53
+ def randn_without_seed(shape):
54
+ # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
55
+ if device.type == 'mps':
56
+ generator = torch.Generator(device=cpu)
57
+ noise = torch.randn(shape, generator=generator, device=cpu).to(device)
58
+ return noise
59
+
60
+ return torch.randn(shape, device=device)
61
+
62
+
63
+ def autocast(disable=False):
64
+ from modules import shared
65
+
66
+ if disable:
67
+ return contextlib.nullcontext()
68
+
69
+ if dtype == torch.float32 or shared.cmd_opts.precision == "full":
70
+ return contextlib.nullcontext()
71
+
72
+ return torch.autocast("cuda")
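A small illustration of the helpers above, assuming the modules package is importable: get_optimal_device picks CUDA, MPS, or CPU, and randn(seed, shape) returns reproducible noise for a given seed.

    from modules import devices

    print(devices.get_optimal_device())

    a = devices.randn(1234, (1, 4, 64, 64))
    b = devices.randn(1234, (1, 4, 64, 64))
    print((a == b).all())  # tensor(True): same seed, same noise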
modules/errors.py ADDED
@@ -0,0 +1,10 @@
1
+ import sys
2
+ import traceback
3
+
4
+
5
+ def run(code, task):
6
+ try:
7
+ code()
8
+ except Exception as e:
9
+ print(f"{task}: {type(e).__name__}", file=sys.stderr)
10
+ print(traceback.format_exc(), file=sys.stderr)
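errors.run is a tiny guard: it calls the given function and, instead of raising, prints the task name and traceback to stderr (devices.py above uses it to wrap enable_tf32). A short illustration with a deliberately failing callable:

    from modules import errors

    def flaky():
        raise RuntimeError("boom")

    errors.run(flaky, "Demonstrating errors.run")  # prints the error, does not raise
    print("still running")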
modules/esrgan_model.py ADDED
@@ -0,0 +1,158 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from basicsr.utils.download_util import load_file_from_url
7
+
8
+ import modules.esrgan_model_arch as arch
9
+ from modules import shared, modelloader, images, devices
10
+ from modules.upscaler import Upscaler, UpscalerData
11
+ from modules.shared import opts
12
+
13
+
14
+ def fix_model_layers(crt_model, pretrained_net):
15
+ # this code is adapted from https://github.com/xinntao/ESRGAN
16
+ if 'conv_first.weight' in pretrained_net:
17
+ return pretrained_net
18
+
19
+ if 'model.0.weight' not in pretrained_net:
20
+ is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
21
+ if is_realesrgan:
22
+ raise Exception("The file is a RealESRGAN model, it can't be used as an ESRGAN model.")
23
+ else:
24
+ raise Exception("The file is not an ESRGAN model.")
25
+
26
+ crt_net = crt_model.state_dict()
27
+ load_net_clean = {}
28
+ for k, v in pretrained_net.items():
29
+ if k.startswith('module.'):
30
+ load_net_clean[k[7:]] = v
31
+ else:
32
+ load_net_clean[k] = v
33
+ pretrained_net = load_net_clean
34
+
35
+ tbd = []
36
+ for k, v in crt_net.items():
37
+ tbd.append(k)
38
+
39
+ # directly copy
40
+ for k, v in crt_net.items():
41
+ if k in pretrained_net and pretrained_net[k].size() == v.size():
42
+ crt_net[k] = pretrained_net[k]
43
+ tbd.remove(k)
44
+
45
+ crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
46
+ crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
47
+
48
+ for k in tbd.copy():
49
+ if 'RDB' in k:
50
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
51
+ if '.weight' in k:
52
+ ori_k = ori_k.replace('.weight', '.0.weight')
53
+ elif '.bias' in k:
54
+ ori_k = ori_k.replace('.bias', '.0.bias')
55
+ crt_net[k] = pretrained_net[ori_k]
56
+ tbd.remove(k)
57
+
58
+ crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
59
+ crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
60
+ crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
61
+ crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
62
+ crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
63
+ crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
64
+ crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
65
+ crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
66
+ crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
67
+ crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
68
+
69
+ return crt_net
70
+
71
+ class UpscalerESRGAN(Upscaler):
72
+ def __init__(self, dirname):
73
+ self.name = "ESRGAN"
74
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
75
+ self.model_name = "ESRGAN_4x"
76
+ self.scalers = []
77
+ self.user_path = dirname
78
+ super().__init__()
79
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
80
+ scalers = []
81
+ if len(model_paths) == 0:
82
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
83
+ scalers.append(scaler_data)
84
+ for file in model_paths:
85
+ if "http" in file:
86
+ name = self.model_name
87
+ else:
88
+ name = modelloader.friendly_name(file)
89
+
90
+ scaler_data = UpscalerData(name, file, self, 4)
91
+ self.scalers.append(scaler_data)
92
+
93
+ def do_upscale(self, img, selected_model):
94
+ model = self.load_model(selected_model)
95
+ if model is None:
96
+ return img
97
+ model.to(devices.device_esrgan)
98
+ img = esrgan_upscale(model, img)
99
+ return img
100
+
101
+ def load_model(self, path: str):
102
+ if "http" in path:
103
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
104
+ file_name="%s.pth" % self.model_name,
105
+ progress=True)
106
+ else:
107
+ filename = path
108
+ if not os.path.exists(filename) or filename is None:
109
+ print("Unable to load %s from %s" % (self.model_path, filename))
110
+ return None
111
+
112
+ pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
113
+ crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
114
+
115
+ pretrained_net = fix_model_layers(crt_model, pretrained_net)
116
+ crt_model.load_state_dict(pretrained_net)
117
+ crt_model.eval()
118
+
119
+ return crt_model
120
+
121
+
122
+ def upscale_without_tiling(model, img):
123
+ img = np.array(img)
124
+ img = img[:, :, ::-1]
125
+ img = np.moveaxis(img, 2, 0) / 255
126
+ img = torch.from_numpy(img).float()
127
+ img = img.unsqueeze(0).to(devices.device_esrgan)
128
+ with torch.no_grad():
129
+ output = model(img)
130
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
131
+ output = 255. * np.moveaxis(output, 0, 2)
132
+ output = output.astype(np.uint8)
133
+ output = output[:, :, ::-1]
134
+ return Image.fromarray(output, 'RGB')
135
+
136
+
137
+ def esrgan_upscale(model, img):
138
+ if opts.ESRGAN_tile == 0:
139
+ return upscale_without_tiling(model, img)
140
+
141
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
142
+ newtiles = []
143
+ scale_factor = 1
144
+
145
+ for y, h, row in grid.tiles:
146
+ newrow = []
147
+ for tiledata in row:
148
+ x, w, tile = tiledata
149
+
150
+ output = upscale_without_tiling(model, tile)
151
+ scale_factor = output.width // tile.width
152
+
153
+ newrow.append([x * scale_factor, w * scale_factor, output])
154
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
155
+
156
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
157
+ output = images.combine_grid(newgrid)
158
+ return output
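A hypothetical end-to-end use of the upscaler above (not part of the commit): it assumes an ESRGAN_4x checkpoint already sits in a models/ESRGAN folder and that shared.opts has been initialised, since esrgan_upscale() reads opts.ESRGAN_tile to choose between whole-image and tiled upscaling.

```python
from PIL import Image
from modules import devices
from modules.esrgan_model import UpscalerESRGAN, esrgan_upscale

upscaler = UpscalerESRGAN("models/ESRGAN")                  # hypothetical models folder
model = upscaler.load_model("models/ESRGAN/ESRGAN_4x.pth")  # hypothetical checkpoint path
model.to(devices.device_esrgan)

img = Image.open("photo.png").convert("RGB")                # hypothetical input image
result = esrgan_upscale(model, img)                         # tiled if opts.ESRGAN_tile > 0, whole image otherwise
result.save("photo_4x.png")
```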
modules/esrgan_model_arch.py ADDED
@@ -0,0 +1,80 @@
+ # this file is taken from https://github.com/xinntao/ESRGAN
+
+ import functools
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def make_layer(block, n_layers):
+     layers = []
+     for _ in range(n_layers):
+         layers.append(block())
+     return nn.Sequential(*layers)
+
+
+ class ResidualDenseBlock_5C(nn.Module):
+     def __init__(self, nf=64, gc=32, bias=True):
+         super(ResidualDenseBlock_5C, self).__init__()
+         # gc: growth channel, i.e. intermediate channels
+         self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+         self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+         self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         return x5 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     '''Residual in Residual Dense Block'''
+
+     def __init__(self, nf, gc=32):
+         super(RRDB, self).__init__()
+         self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+         self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+         self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+     def forward(self, x):
+         out = self.RDB1(x)
+         out = self.RDB2(out)
+         out = self.RDB3(out)
+         return out * 0.2 + x
+
+
+ class RRDBNet(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, gc=32):
+         super(RRDBNet, self).__init__()
+         RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+
+         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+         self.RRDB_trunk = make_layer(RRDB_block_f, nb)
+         self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         #### upsampling
+         self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+         fea = self.conv_first(x)
+         trunk = self.trunk_conv(self.RRDB_trunk(fea))
+         fea = fea + trunk
+
+         fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+         fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+         out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+         return out
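RRDBNet always upsamples by 4x: the RRDB trunk preserves spatial size and the head applies two nearest-neighbour 2x interpolations. A standalone shape check (illustrative; needs only torch):

```python
import torch
from modules.esrgan_model_arch import RRDBNet

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32)  # the configuration used in esrgan_model.py
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]) -- two 2x nearest-neighbour upsampling steps
```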
modules/extras.py ADDED
@@ -0,0 +1,222 @@
1
+ import math
2
+ import os
3
+
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import torch
8
+ import tqdm
9
+
10
+ from modules import processing, shared, images, devices, sd_models
11
+ from modules.shared import opts
12
+ import modules.gfpgan_model
13
+ from modules.ui import plaintext_to_html
14
+ import modules.codeformer_model
15
+ import piexif
16
+ import piexif.helper
17
+ import gradio as gr
18
+
19
+
20
+ cached_images = {}
21
+
22
+
23
+ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
24
+ devices.torch_gc()
25
+
26
+ imageArr = []
27
+ # Also keep track of original file names
28
+ imageNameArr = []
29
+
30
+ if extras_mode == 1:
31
+ #convert file to pillow image
32
+ for img in image_folder:
33
+ image = Image.open(img)
34
+ imageArr.append(image)
35
+ imageNameArr.append(os.path.splitext(img.orig_name)[0])
36
+ else:
37
+ imageArr.append(image)
38
+ imageNameArr.append(None)
39
+
40
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
41
+
42
+ outputs = []
43
+ for image, image_name in zip(imageArr, imageNameArr):
44
+ if image is None:
45
+ return outputs, "Please select an input image.", ''
46
+ existing_pnginfo = image.info or {}
47
+
48
+ image = image.convert("RGB")
49
+ info = ""
50
+
51
+ if gfpgan_visibility > 0:
52
+ restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
53
+ res = Image.fromarray(restored_img)
54
+
55
+ if gfpgan_visibility < 1.0:
56
+ res = Image.blend(image, res, gfpgan_visibility)
57
+
58
+ info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
59
+ image = res
60
+
61
+ if codeformer_visibility > 0:
62
+ restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
63
+ res = Image.fromarray(restored_img)
64
+
65
+ if codeformer_visibility < 1.0:
66
+ res = Image.blend(image, res, codeformer_visibility)
67
+
68
+ info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
69
+ image = res
70
+
71
+ if resize_mode == 1:
72
+ upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
73
+ crop_info = " (crop)" if upscaling_crop else ""
74
+ info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
75
+
76
+ if upscaling_resize != 1.0:
77
+ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
78
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
79
+ pixels = tuple(np.array(small).flatten().tolist())
80
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
81
+
82
+ c = cached_images.get(key)
83
+ if c is None:
84
+ upscaler = shared.sd_upscalers[scaler_index]
85
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
86
+ if mode == 1 and crop:
87
+ cropped = Image.new("RGB", (resize_w, resize_h))
88
+ cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
89
+ c = cropped
90
+ cached_images[key] = c
91
+
92
+ return c
93
+
94
+ info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
95
+ res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
96
+
97
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
98
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
99
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
100
+ res = Image.blend(res, res2, extras_upscaler_2_visibility)
101
+
102
+ image = res
103
+
104
+ while len(cached_images) > 2:
105
+ del cached_images[next(iter(cached_images.keys()))]
106
+
107
+ images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
108
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
109
+ forced_filename=image_name if opts.use_original_name_batch else None)
110
+
111
+ if opts.enable_pnginfo:
112
+ image.info = existing_pnginfo
113
+ image.info["extras"] = info
114
+
115
+ outputs.append(image)
116
+
117
+ devices.torch_gc()
118
+
119
+ return outputs, plaintext_to_html(info), ''
120
+
121
+
122
+ def run_pnginfo(image):
123
+ if image is None:
124
+ return '', '', ''
125
+
126
+ items = image.info
127
+ geninfo = ''
128
+
129
+ if "exif" in image.info:
130
+ exif = piexif.load(image.info["exif"])
131
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
132
+ try:
133
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
134
+ except ValueError:
135
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
136
+
137
+ items['exif comment'] = exif_comment
138
+ geninfo = exif_comment
139
+
140
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
141
+ 'loop', 'background', 'timestamp', 'duration']:
142
+ items.pop(field, None)
143
+
144
+ geninfo = items.get('parameters', geninfo)
145
+
146
+ info = ''
147
+ for key, text in items.items():
148
+ info += f"""
149
+ <div>
150
+ <p><b>{plaintext_to_html(str(key))}</b></p>
151
+ <p>{plaintext_to_html(str(text))}</p>
152
+ </div>
153
+ """.strip()+"\n"
154
+
155
+ if len(info) == 0:
156
+ message = "Nothing found in the image."
157
+ info = f"<div><p>{message}<p></div>"
158
+
159
+ return '', geninfo, info
160
+
161
+
162
+ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
163
+ # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
164
+ def weighted_sum(theta0, theta1, alpha):
165
+ return ((1 - alpha) * theta0) + (alpha * theta1)
166
+
167
+ # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
168
+ def sigmoid(theta0, theta1, alpha):
169
+ alpha = alpha * alpha * (3 - (2 * alpha))
170
+ return theta0 + ((theta1 - theta0) * alpha)
171
+
172
+ # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
173
+ def inv_sigmoid(theta0, theta1, alpha):
174
+ import math
175
+ alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
176
+ return theta0 + ((theta1 - theta0) * alpha)
177
+
178
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
179
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
180
+
181
+ print(f"Loading {primary_model_info.filename}...")
182
+ primary_model = torch.load(primary_model_info.filename, map_location='cpu')
183
+
184
+ print(f"Loading {secondary_model_info.filename}...")
185
+ secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
186
+
187
+ theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
188
+ theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
189
+
190
+ theta_funcs = {
191
+ "Weighted Sum": weighted_sum,
192
+ "Sigmoid": sigmoid,
193
+ "Inverse Sigmoid": inv_sigmoid,
194
+ }
195
+ theta_func = theta_funcs[interp_method]
196
+
197
+ print(f"Merging...")
198
+ for key in tqdm.tqdm(theta_0.keys()):
199
+ if 'model' in key and key in theta_1:
200
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
201
+ if save_as_half:
202
+ theta_0[key] = theta_0[key].half()
203
+
204
+ for key in theta_1.keys():
205
+ if 'model' in key and key not in theta_0:
206
+ theta_0[key] = theta_1[key]
207
+ if save_as_half:
208
+ theta_0[key] = theta_0[key].half()
209
+
210
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
211
+
212
+ filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
213
+ filename = filename if custom_name == '' else (custom_name + '.ckpt')
214
+ output_modelname = os.path.join(ckpt_dir, filename)
215
+
216
+ print(f"Saving to {output_modelname}...")
217
+ torch.save(primary_model, output_modelname)
218
+
219
+ sd_models.list_models()
220
+
221
+ print(f"Checkpoint saved.")
222
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
modules/face_restoration.py ADDED
@@ -0,0 +1,19 @@
+ from modules import shared
+
+
+ class FaceRestoration:
+     def name(self):
+         return "None"
+
+     def restore(self, np_image):
+         return np_image
+
+
+ def restore_faces(np_image):
+     face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
+     if len(face_restorers) == 0:
+         return np_image
+
+     face_restorer = face_restorers[0]
+
+     return face_restorer.restore(np_image)
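restore_faces() simply dispatches to the first registered restorer whose name() matches opts.face_restoration_model, so adding a backend means subclassing FaceRestoration and appending an instance to shared.face_restorers, as gfpgan_model.py does below. A minimal sketch (the Identity restorer is made up):

```python
from modules import face_restoration, shared

class FaceRestorerIdentity(face_restoration.FaceRestoration):
    def name(self):
        return "Identity"              # the name shown in the settings dropdown

    def restore(self, np_image):
        return np_image                # a real backend would repair faces here

shared.face_restorers.append(FaceRestorerIdentity())
```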
modules/generation_parameters_copypaste.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import re
+ import gradio as gr
+ from modules.shared import script_path
+ from modules import shared
+
+ re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+ re_param = re.compile(re_param_code)
+ re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
+ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+ type_of_gr_update = type(gr.update())
+
+
+ def parse_generation_parameters(x: str):
+     """parses generation parameters string, the one you see in text field under the picture in UI:
+ ```
+ girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
+ Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
+ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
+ ```
+
+     returns a dict with field values
+     """
+
+     res = {}
+
+     prompt = ""
+     negative_prompt = ""
+
+     done_with_prompt = False
+
+     *lines, lastline = x.strip().split("\n")
+     if not re_params.match(lastline):
+         lines.append(lastline)
+         lastline = ''
+
+     for i, line in enumerate(lines):
+         line = line.strip()
+         if line.startswith("Negative prompt:"):
+             done_with_prompt = True
+             line = line[16:].strip()
+
+         if done_with_prompt:
+             negative_prompt += ("" if negative_prompt == "" else "\n") + line
+         else:
+             prompt += ("" if prompt == "" else "\n") + line
+
+     if len(prompt) > 0:
+         res["Prompt"] = prompt
+
+     if len(negative_prompt) > 0:
+         res["Negative prompt"] = negative_prompt
+
+     for k, v in re_param.findall(lastline):
+         m = re_imagesize.match(v)
+         if m is not None:
+             res[k+"-1"] = m.group(1)
+             res[k+"-2"] = m.group(2)
+         else:
+             res[k] = v
+
+     return res
+
+
+ def connect_paste(button, paste_fields, input_comp, js=None):
+     def paste_func(prompt):
+         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+             filename = os.path.join(script_path, "params.txt")
+             if os.path.exists(filename):
+                 with open(filename, "r", encoding="utf8") as file:
+                     prompt = file.read()
+
+         params = parse_generation_parameters(prompt)
+         res = []
+
+         for output, key in paste_fields:
+             if callable(key):
+                 v = key(params)
+             else:
+                 v = params.get(key, None)
+
+             if v is None:
+                 res.append(gr.update())
+             elif isinstance(v, type_of_gr_update):
+                 res.append(v)
+             else:
+                 try:
+                     valtype = type(output.value)
+                     val = valtype(v)
+                     res.append(gr.update(value=val))
+                 except Exception:
+                     res.append(gr.update())
+
+         return res
+
+     button.click(
+         fn=paste_func,
+         _js=js,
+         inputs=[input_comp],
+         outputs=[x[0] for x in paste_fields],
+     )
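parse_generation_parameters() returns a flat dict of strings: the prompt block under "Prompt"/"Negative prompt", each "key: value" pair from the last line under its own key, and WxH sizes split into "-1"/"-2" keys. An illustrative call with a shortened infotext:

```python
from modules.generation_parameters_copypaste import parse_generation_parameters

text = """a photo of a cat
Negative prompt: blurry
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x768, Model hash: 45dee52b"""

params = parse_generation_parameters(text)
print(params["Prompt"])                      # 'a photo of a cat'
print(params["Negative prompt"])             # 'blurry'
print(params["Steps"], params["Sampler"])    # '20' 'Euler a'
print(params["Size-1"], params["Size-2"])    # '512' '768'
```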
modules/gfpgan_model.py ADDED
@@ -0,0 +1,115 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import facexlib
6
+ import gfpgan
7
+
8
+ import modules.face_restoration
9
+ from modules import shared, devices, modelloader
10
+ from modules.paths import models_path
11
+
12
+ model_dir = "GFPGAN"
13
+ user_path = None
14
+ model_path = os.path.join(models_path, model_dir)
15
+ model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
16
+ have_gfpgan = False
17
+ loaded_gfpgan_model = None
18
+
19
+
20
+ def gfpgann():
21
+ global loaded_gfpgan_model
22
+ global model_path
23
+ if loaded_gfpgan_model is not None:
24
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
25
+ return loaded_gfpgan_model
26
+
27
+ if gfpgan_constructor is None:
28
+ return None
29
+
30
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
31
+ if len(models) == 1 and "http" in models[0]:
32
+ model_file = models[0]
33
+ elif len(models) != 0:
34
+ latest_file = max(models, key=os.path.getctime)
35
+ model_file = latest_file
36
+ else:
37
+ print("Unable to load gfpgan model!")
38
+ return None
39
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
40
+ loaded_gfpgan_model = model
41
+
42
+ return model
43
+
44
+
45
+ def send_model_to(model, device):
46
+ model.gfpgan.to(device)
47
+ model.face_helper.face_det.to(device)
48
+ model.face_helper.face_parse.to(device)
49
+
50
+
51
+ def gfpgan_fix_faces(np_image):
52
+ model = gfpgann()
53
+ if model is None:
54
+ return np_image
55
+
56
+ send_model_to(model, devices.device_gfpgan)
57
+
58
+ np_image_bgr = np_image[:, :, ::-1]
59
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
60
+ np_image = gfpgan_output_bgr[:, :, ::-1]
61
+
62
+ model.face_helper.clean_all()
63
+
64
+ if shared.opts.face_restoration_unload:
65
+ send_model_to(model, devices.cpu)
66
+
67
+ return np_image
68
+
69
+
70
+ gfpgan_constructor = None
71
+
72
+
73
+ def setup_model(dirname):
74
+ global model_path
75
+ if not os.path.exists(model_path):
76
+ os.makedirs(model_path)
77
+
78
+ try:
79
+ from gfpgan import GFPGANer
80
+ from facexlib import detection, parsing
81
+ global user_path
82
+ global have_gfpgan
83
+ global gfpgan_constructor
84
+
85
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
86
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
87
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
88
+
89
+ def my_load_file_from_url(**kwargs):
90
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
91
+
92
+ def facex_load_file_from_url(**kwargs):
93
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
94
+
95
+ def facex_load_file_from_url2(**kwargs):
96
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
97
+
98
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
99
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
100
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
101
+ user_path = dirname
102
+ have_gfpgan = True
103
+ gfpgan_constructor = GFPGANer
104
+
105
+ class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
106
+ def name(self):
107
+ return "GFPGAN"
108
+
109
+ def restore(self, np_image):
110
+ return gfpgan_fix_faces(np_image)
111
+
112
+ shared.face_restorers.append(FaceRestorerGFPGAN())
113
+ except Exception:
114
+ print("Error setting up GFPGAN:", file=sys.stderr)
115
+ print(traceback.format_exc(), file=sys.stderr)
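A hypothetical call sequence for the wrapper above (not part of the commit; it requires the gfpgan and facexlib packages, and the GFPGAN checkpoint is downloaded on first use): setup_model() registers the restorer and patches the download helpers, then gfpgan_fix_faces() operates on RGB uint8 arrays.

```python
import numpy as np
from PIL import Image
import modules.gfpgan_model as gfpgan_model

gfpgan_model.setup_model("models/GFPGAN")              # hypothetical models folder
img = np.array(Image.open("face.png").convert("RGB"), dtype=np.uint8)
restored = gfpgan_model.gfpgan_fix_faces(img)          # returns an RGB uint8 array
Image.fromarray(restored).save("face_restored.png")
```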
modules/hypernetworks/hypernetwork.py ADDED
@@ -0,0 +1,314 @@
1
+ import datetime
2
+ import glob
3
+ import html
4
+ import os
5
+ import sys
6
+ import traceback
7
+ import tqdm
8
+
9
+ import torch
10
+
11
+ from ldm.util import default
12
+ from modules import devices, shared, processing, sd_models
13
+ import torch
14
+ from torch import einsum
15
+ from einops import rearrange, repeat
16
+ import modules.textual_inversion.dataset
17
+ from modules.textual_inversion.learn_schedule import LearnRateScheduler
18
+
19
+
20
+ class HypernetworkModule(torch.nn.Module):
21
+ multiplier = 1.0
22
+
23
+ def __init__(self, dim, state_dict=None):
24
+ super().__init__()
25
+
26
+ self.linear1 = torch.nn.Linear(dim, dim * 2)
27
+ self.linear2 = torch.nn.Linear(dim * 2, dim)
28
+
29
+ if state_dict is not None:
30
+ self.load_state_dict(state_dict, strict=True)
31
+ else:
32
+
33
+ self.linear1.weight.data.normal_(mean=0.0, std=0.01)
34
+ self.linear1.bias.data.zero_()
35
+ self.linear2.weight.data.normal_(mean=0.0, std=0.01)
36
+ self.linear2.bias.data.zero_()
37
+
38
+ self.to(devices.device)
39
+
40
+ def forward(self, x):
41
+ return x + (self.linear2(self.linear1(x))) * self.multiplier
42
+
43
+
44
+ def apply_strength(value=None):
45
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
46
+
47
+
48
+ class Hypernetwork:
49
+ filename = None
50
+ name = None
51
+
52
+ def __init__(self, name=None, enable_sizes=None):
53
+ self.filename = None
54
+ self.name = name
55
+ self.layers = {}
56
+ self.step = 0
57
+ self.sd_checkpoint = None
58
+ self.sd_checkpoint_name = None
59
+
60
+ for size in enable_sizes or []:
61
+ self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
62
+
63
+ def weights(self):
64
+ res = []
65
+
66
+ for k, layers in self.layers.items():
67
+ for layer in layers:
68
+ layer.train()
69
+ res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
70
+
71
+ return res
72
+
73
+ def save(self, filename):
74
+ state_dict = {}
75
+
76
+ for k, v in self.layers.items():
77
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
78
+
79
+ state_dict['step'] = self.step
80
+ state_dict['name'] = self.name
81
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
82
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
83
+
84
+ torch.save(state_dict, filename)
85
+
86
+ def load(self, filename):
87
+ self.filename = filename
88
+ if self.name is None:
89
+ self.name = os.path.splitext(os.path.basename(filename))[0]
90
+
91
+ state_dict = torch.load(filename, map_location='cpu')
92
+
93
+ for size, sd in state_dict.items():
94
+ if type(size) == int:
95
+ self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
96
+
97
+ self.name = state_dict.get('name', self.name)
98
+ self.step = state_dict.get('step', 0)
99
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
100
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
101
+
102
+
103
+ def list_hypernetworks(path):
104
+ res = {}
105
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
106
+ name = os.path.splitext(os.path.basename(filename))[0]
107
+ res[name] = filename
108
+ return res
109
+
110
+
111
+ def load_hypernetwork(filename):
112
+ path = shared.hypernetworks.get(filename, None)
113
+ if path is not None:
114
+ print(f"Loading hypernetwork {filename}")
115
+ try:
116
+ shared.loaded_hypernetwork = Hypernetwork()
117
+ shared.loaded_hypernetwork.load(path)
118
+
119
+ except Exception:
120
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
121
+ print(traceback.format_exc(), file=sys.stderr)
122
+ else:
123
+ if shared.loaded_hypernetwork is not None:
124
+ print(f"Unloading hypernetwork")
125
+
126
+ shared.loaded_hypernetwork = None
127
+
128
+
129
+ def find_closest_hypernetwork_name(search: str):
130
+ if not search:
131
+ return None
132
+ search = search.lower()
133
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
134
+ if not applicable:
135
+ return None
136
+ applicable = sorted(applicable, key=lambda name: len(name))
137
+ return applicable[0]
138
+
139
+
140
+ def apply_hypernetwork(hypernetwork, context, layer=None):
141
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
142
+
143
+ if hypernetwork_layers is None:
144
+ return context, context
145
+
146
+ if layer is not None:
147
+ layer.hyper_k = hypernetwork_layers[0]
148
+ layer.hyper_v = hypernetwork_layers[1]
149
+
150
+ context_k = hypernetwork_layers[0](context)
151
+ context_v = hypernetwork_layers[1](context)
152
+ return context_k, context_v
153
+
154
+
155
+ def attention_CrossAttention_forward(self, x, context=None, mask=None):
156
+ h = self.heads
157
+
158
+ q = self.to_q(x)
159
+ context = default(context, x)
160
+
161
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
162
+ k = self.to_k(context_k)
163
+ v = self.to_v(context_v)
164
+
165
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
166
+
167
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
168
+
169
+ if mask is not None:
170
+ mask = rearrange(mask, 'b ... -> b (...)')
171
+ max_neg_value = -torch.finfo(sim.dtype).max
172
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
173
+ sim.masked_fill_(~mask, max_neg_value)
174
+
175
+ # attention, what we cannot get enough of
176
+ attn = sim.softmax(dim=-1)
177
+
178
+ out = einsum('b i j, b j d -> b i d', attn, v)
179
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
180
+ return self.to_out(out)
181
+
182
+
183
+ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
184
+ assert hypernetwork_name, 'hypernetwork not selected'
185
+
186
+ path = shared.hypernetworks.get(hypernetwork_name, None)
187
+ shared.loaded_hypernetwork = Hypernetwork()
188
+ shared.loaded_hypernetwork.load(path)
189
+
190
+ shared.state.textinfo = "Initializing hypernetwork training..."
191
+ shared.state.job_count = steps
192
+
193
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
194
+
195
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
196
+ unload = shared.opts.unload_models_when_training
197
+
198
+ if save_hypernetwork_every > 0:
199
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
200
+ os.makedirs(hypernetwork_dir, exist_ok=True)
201
+ else:
202
+ hypernetwork_dir = None
203
+
204
+ if create_image_every > 0:
205
+ images_dir = os.path.join(log_directory, "images")
206
+ os.makedirs(images_dir, exist_ok=True)
207
+ else:
208
+ images_dir = None
209
+
210
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
211
+ with torch.autocast("cuda"):
212
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
213
+
214
+ if unload:
215
+ shared.sd_model.cond_stage_model.to(devices.cpu)
216
+ shared.sd_model.first_stage_model.to(devices.cpu)
217
+
218
+ hypernetwork = shared.loaded_hypernetwork
219
+ weights = hypernetwork.weights()
220
+ for weight in weights:
221
+ weight.requires_grad = True
222
+
223
+ losses = torch.zeros((32,))
224
+
225
+ last_saved_file = "<none>"
226
+ last_saved_image = "<none>"
227
+
228
+ ititial_step = hypernetwork.step or 0
229
+ if ititial_step > steps:
230
+ return hypernetwork, filename
231
+
232
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
233
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
234
+
235
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
236
+ for i, entry in pbar:
237
+ hypernetwork.step = i + ititial_step
238
+
239
+ scheduler.apply(optimizer, hypernetwork.step)
240
+ if scheduler.finished:
241
+ break
242
+
243
+ if shared.state.interrupted:
244
+ break
245
+
246
+ with torch.autocast("cuda"):
247
+ cond = entry.cond.to(devices.device)
248
+ x = entry.latent.to(devices.device)
249
+ loss = shared.sd_model(x.unsqueeze(0), cond)[0]
250
+ del x
251
+ del cond
252
+
253
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
254
+
255
+ optimizer.zero_grad()
256
+ loss.backward()
257
+ optimizer.step()
258
+
259
+ pbar.set_description(f"loss: {losses.mean():.7f}")
260
+
261
+ if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
262
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
263
+ hypernetwork.save(last_saved_file)
264
+
265
+ if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
266
+ last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
267
+
268
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
269
+
270
+ optimizer.zero_grad()
271
+ shared.sd_model.cond_stage_model.to(devices.device)
272
+ shared.sd_model.first_stage_model.to(devices.device)
273
+
274
+ p = processing.StableDiffusionProcessingTxt2Img(
275
+ sd_model=shared.sd_model,
276
+ prompt=preview_text,
277
+ steps=20,
278
+ do_not_save_grid=True,
279
+ do_not_save_samples=True,
280
+ )
281
+
282
+ processed = processing.process_images(p)
283
+ image = processed.images[0] if len(processed.images)>0 else None
284
+
285
+ if unload:
286
+ shared.sd_model.cond_stage_model.to(devices.cpu)
287
+ shared.sd_model.first_stage_model.to(devices.cpu)
288
+
289
+ if image is not None:
290
+ shared.state.current_image = image
291
+ image.save(last_saved_image)
292
+ last_saved_image += f", prompt: {preview_text}"
293
+
294
+ shared.state.job_no = hypernetwork.step
295
+
296
+ shared.state.textinfo = f"""
297
+ <p>
298
+ Loss: {losses.mean():.7f}<br/>
299
+ Step: {hypernetwork.step}<br/>
300
+ Last prompt: {html.escape(entry.cond_text)}<br/>
301
+ Last saved embedding: {html.escape(last_saved_file)}<br/>
302
+ Last saved image: {html.escape(last_saved_image)}<br/>
303
+ </p>
304
+ """
305
+
306
+ checkpoint = sd_models.select_checkpoint()
307
+
308
+ hypernetwork.sd_checkpoint = checkpoint.hash
309
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
310
+ hypernetwork.save(filename)
311
+
312
+ return hypernetwork, filename
313
+
314
+
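HypernetworkModule is a residual two-layer MLP (dim -> 2*dim -> dim) applied separately to the key and value context in cross-attention via apply_hypernetwork(); the transform preserves the context shape. A shape-level sketch (illustrative; assumes modules.devices resolves to a usable device):

```python
import torch
from modules import devices
from modules.hypernetworks.hypernetwork import HypernetworkModule

k_module = HypernetworkModule(768)                          # 768 = CLIP context width for SD 1.x
context = torch.randn(2, 77, 768, device=devices.device)    # (batch, tokens, width)
print(k_module(context).shape)                              # torch.Size([2, 77, 768])
```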
modules/hypernetworks/ui.py ADDED
@@ -0,0 +1,47 @@
+ import html
+ import os
+
+ import gradio as gr
+
+ import modules.textual_inversion.textual_inversion
+ import modules.textual_inversion.preprocess
+ from modules import sd_hijack, shared, devices
+ from modules.hypernetworks import hypernetwork
+
+
+ def create_hypernetwork(name, enable_sizes):
+     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+     assert not os.path.exists(fn), f"file {fn} already exists"
+
+     hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+     hypernet.save(fn)
+
+     shared.reload_hypernetworks()
+
+     return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+
+
+ def train_hypernetwork(*args):
+
+     initial_hypernetwork = shared.loaded_hypernetwork
+
+     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
+     try:
+         sd_hijack.undo_optimizations()
+
+         hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
+
+         res = f"""
+ Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+ Hypernetwork saved to {html.escape(filename)}
+ """
+         return res, ""
+     except Exception:
+         raise
+     finally:
+         shared.loaded_hypernetwork = initial_hypernetwork
+         shared.sd_model.cond_stage_model.to(devices.device)
+         shared.sd_model.first_stage_model.to(devices.device)
+         sd_hijack.apply_optimizations()
+
modules/images.py ADDED
@@ -0,0 +1,465 @@
1
+ import datetime
2
+ import math
3
+ import os
4
+ from collections import namedtuple
5
+ import re
6
+
7
+ import numpy as np
8
+ import piexif
9
+ import piexif.helper
10
+ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
11
+ from fonts.ttf import Roboto
12
+ import string
13
+
14
+ from modules import sd_samplers, shared
15
+ from modules.shared import opts, cmd_opts
16
+
17
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
18
+
19
+
20
+ def image_grid(imgs, batch_size=1, rows=None):
21
+ if rows is None:
22
+ if opts.n_rows > 0:
23
+ rows = opts.n_rows
24
+ elif opts.n_rows == 0:
25
+ rows = batch_size
26
+ else:
27
+ rows = math.sqrt(len(imgs))
28
+ rows = round(rows)
29
+
30
+ cols = math.ceil(len(imgs) / rows)
31
+
32
+ w, h = imgs[0].size
33
+ grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
34
+
35
+ for i, img in enumerate(imgs):
36
+ grid.paste(img, box=(i % cols * w, i // cols * h))
37
+
38
+ return grid
39
+
40
+
41
+ Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
42
+
43
+
44
+ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
45
+ w = image.width
46
+ h = image.height
47
+
48
+ non_overlap_width = tile_w - overlap
49
+ non_overlap_height = tile_h - overlap
50
+
51
+ cols = math.ceil((w - overlap) / non_overlap_width)
52
+ rows = math.ceil((h - overlap) / non_overlap_height)
53
+
54
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
55
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
56
+
57
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
58
+ for row in range(rows):
59
+ row_images = []
60
+
61
+ y = int(row * dy)
62
+
63
+ if y + tile_h >= h:
64
+ y = h - tile_h
65
+
66
+ for col in range(cols):
67
+ x = int(col * dx)
68
+
69
+ if x + tile_w >= w:
70
+ x = w - tile_w
71
+
72
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
73
+
74
+ row_images.append([x, tile_w, tile])
75
+
76
+ grid.tiles.append([y, tile_h, row_images])
77
+
78
+ return grid
79
+
80
+
81
+ def combine_grid(grid):
82
+ def make_mask_image(r):
83
+ r = r * 255 / grid.overlap
84
+ r = r.astype(np.uint8)
85
+ return Image.fromarray(r, 'L')
86
+
87
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
88
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
89
+
90
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
91
+ for y, h, row in grid.tiles:
92
+ combined_row = Image.new("RGB", (grid.image_w, h))
93
+ for x, w, tile in row:
94
+ if x == 0:
95
+ combined_row.paste(tile, (0, 0))
96
+ continue
97
+
98
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
99
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
100
+
101
+ if y == 0:
102
+ combined_image.paste(combined_row, (0, 0))
103
+ continue
104
+
105
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
106
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
107
+
108
+ return combined_image
109
+
110
+
111
+ class GridAnnotation:
112
+ def __init__(self, text='', is_active=True):
113
+ self.text = text
114
+ self.is_active = is_active
115
+ self.size = None
116
+
117
+
118
+ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
119
+ def wrap(drawing, text, font, line_length):
120
+ lines = ['']
121
+ for word in text.split():
122
+ line = f'{lines[-1]} {word}'.strip()
123
+ if drawing.textlength(line, font=font) <= line_length:
124
+ lines[-1] = line
125
+ else:
126
+ lines.append(word)
127
+ return lines
128
+
129
+ def draw_texts(drawing, draw_x, draw_y, lines):
130
+ for i, line in enumerate(lines):
131
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
132
+
133
+ if not line.is_active:
134
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
135
+
136
+ draw_y += line.size[1] + line_spacing
137
+
138
+ fontsize = (width + height) // 25
139
+ line_spacing = fontsize // 2
140
+
141
+ try:
142
+ fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
143
+ except Exception:
144
+ fnt = ImageFont.truetype(Roboto, fontsize)
145
+
146
+ color_active = (0, 0, 0)
147
+ color_inactive = (153, 153, 153)
148
+
149
+ pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
150
+
151
+ cols = im.width // width
152
+ rows = im.height // height
153
+
154
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
155
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
156
+
157
+ calc_img = Image.new("RGB", (1, 1), "white")
158
+ calc_d = ImageDraw.Draw(calc_img)
159
+
160
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
161
+ items = [] + texts
162
+ texts.clear()
163
+
164
+ for line in items:
165
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
166
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
167
+
168
+ for line in texts:
169
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
170
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
171
+
172
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
173
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
174
+ ver_texts]
175
+
176
+ pad_top = max(hor_text_heights) + line_spacing * 2
177
+
178
+ result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
179
+ result.paste(im, (pad_left, pad_top))
180
+
181
+ d = ImageDraw.Draw(result)
182
+
183
+ for col in range(cols):
184
+ x = pad_left + width * col + width / 2
185
+ y = pad_top / 2 - hor_text_heights[col] / 2
186
+
187
+ draw_texts(d, x, y, hor_texts[col])
188
+
189
+ for row in range(rows):
190
+ x = pad_left / 2
191
+ y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
192
+
193
+ draw_texts(d, x, y, ver_texts[row])
194
+
195
+ return result
196
+
197
+
198
+ def draw_prompt_matrix(im, width, height, all_prompts):
199
+ prompts = all_prompts[1:]
200
+ boundary = math.ceil(len(prompts) / 2)
201
+
202
+ prompts_horiz = prompts[:boundary]
203
+ prompts_vert = prompts[boundary:]
204
+
205
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
206
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
207
+
208
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
209
+
210
+
211
+ def resize_image(resize_mode, im, width, height):
212
+ def resize(im, w, h):
213
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
214
+ return im.resize((w, h), resample=LANCZOS)
215
+
216
+ scale = max(w / im.width, h / im.height)
217
+
218
+ if scale > 1.0:
219
+ upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
220
+ assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
221
+
222
+ upscaler = upscalers[0]
223
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
224
+
225
+ if im.width != w or im.height != h:
226
+ im = im.resize((w, h), resample=LANCZOS)
227
+
228
+ return im
229
+
230
+ if resize_mode == 0:
231
+ res = resize(im, width, height)
232
+
233
+ elif resize_mode == 1:
234
+ ratio = width / height
235
+ src_ratio = im.width / im.height
236
+
237
+ src_w = width if ratio > src_ratio else im.width * height // im.height
238
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
239
+
240
+ resized = resize(im, src_w, src_h)
241
+ res = Image.new("RGB", (width, height))
242
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
243
+
244
+ else:
245
+ ratio = width / height
246
+ src_ratio = im.width / im.height
247
+
248
+ src_w = width if ratio < src_ratio else im.width * height // im.height
249
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
250
+
251
+ resized = resize(im, src_w, src_h)
252
+ res = Image.new("RGB", (width, height))
253
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
254
+
255
+ if ratio < src_ratio:
256
+ fill_height = height // 2 - src_h // 2
257
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
258
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
259
+ elif ratio > src_ratio:
260
+ fill_width = width // 2 - src_w // 2
261
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
262
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
263
+
264
+ return res
265
+
266
+
267
+ invalid_filename_chars = '<>:"/\\|?*\n'
268
+ invalid_filename_prefix = ' '
269
+ invalid_filename_postfix = ' .'
270
+ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
271
+ max_filename_part_length = 128
272
+
273
+
274
+ def sanitize_filename_part(text, replace_spaces=True):
275
+ if replace_spaces:
276
+ text = text.replace(' ', '_')
277
+
278
+ text = text.translate({ord(x): '_' for x in invalid_filename_chars})
279
+ text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
280
+ text = text.rstrip(invalid_filename_postfix)
281
+ return text
282
+
283
+
284
+ def apply_filename_pattern(x, p, seed, prompt):
285
+ max_prompt_words = opts.directories_max_prompt_words
286
+
287
+ if seed is not None:
288
+ x = x.replace("[seed]", str(seed))
289
+
290
+ if p is not None:
291
+ x = x.replace("[steps]", str(p.steps))
292
+ x = x.replace("[cfg]", str(p.cfg_scale))
293
+ x = x.replace("[width]", str(p.width))
294
+ x = x.replace("[height]", str(p.height))
295
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
296
+ x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
297
+
298
+ x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
299
+ x = x.replace("[date]", datetime.date.today().isoformat())
300
+ x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
301
+ x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
302
+
303
+ # Apply [prompt] at last. Because it may contain any replacement word.^M
304
+ if prompt is not None:
305
+ x = x.replace("[prompt]", sanitize_filename_part(prompt))
306
+ if "[prompt_no_styles]" in x:
307
+ prompt_no_style = prompt
308
+ for style in shared.prompt_styles.get_style_prompts(p.styles):
309
+ if len(style) > 0:
310
+ style_parts = [y for y in style.split("{prompt}")]
311
+ for part in style_parts:
312
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
313
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
314
+ x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
315
+
316
+ x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
317
+ if "[prompt_words]" in x:
318
+ words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
319
+ if len(words) == 0:
320
+ words = ["empty"]
321
+ x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
322
+
323
+ if cmd_opts.hide_ui_dir_config:
324
+ x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
325
+
326
+ return x
327
+
328
+
329
+ def get_next_sequence_number(path, basename):
330
+ """
331
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
332
+
333
+ The sequence starts at 0.
334
+ """
335
+ result = -1
336
+ if basename != '':
337
+ basename = basename + "-"
338
+
339
+ prefix_length = len(basename)
340
+ for p in os.listdir(path):
341
+ if p.startswith(basename):
342
+ l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
343
+ try:
344
+ result = max(int(l[0]), result)
345
+ except ValueError:
346
+ pass
347
+
348
+ return result + 1
349
+
350
+
351
+ def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
352
+ '''Save an image.
353
+
354
+ Args:
355
+ image (`PIL.Image`):
356
+ The image to be saved.
357
+ path (`str`):
358
+ The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
359
+ basename (`str`):
360
+ The base filename which will be applied to `filename pattern`.
361
+ seed, prompt, short_filename,
362
+ extension (`str`):
363
+ Image file extension, default is `png`.
364
+ pngsectionname (`str`):
365
+ Specify the name of the section which `info` will be saved in.
366
+ info (`str` or `PngImagePlugin.iTXt`):
367
+ PNG info chunks.
368
+ existing_info (`dict`):
369
+ Additional PNG info. `existing_info == {pngsectionname: info, ...}`
370
+ no_prompt:
371
+ TODO I don't know its meaning.
372
+ p (`StableDiffusionProcessing`)
373
+ forced_filename (`str`):
374
+ If specified, `basename` and filename pattern will be ignored.
375
+ save_to_dirs (bool):
376
+ If true, the image will be saved into a subdirectory of `path`.
377
+
378
+ Returns: (fullfn, txt_fullfn)
379
+ fullfn (`str`):
380
+ The full path of the saved imaged.
381
+ txt_fullfn (`str` or None):
382
+ If a text file is saved for this image, this will be its full path. Otherwise None.
383
+ '''
384
+ if short_filename or prompt is None or seed is None:
385
+ file_decoration = ""
386
+ elif opts.save_to_dirs:
387
+ file_decoration = opts.samples_filename_pattern or "[seed]"
388
+ else:
389
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
390
+
391
+ if file_decoration != "":
392
+ file_decoration = "-" + file_decoration.lower()
393
+
394
+ file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
395
+
396
+ if extension == 'png' and opts.enable_pnginfo and info is not None:
397
+ pnginfo = PngImagePlugin.PngInfo()
398
+
399
+ if existing_info is not None:
400
+ for k, v in existing_info.items():
401
+ pnginfo.add_text(k, str(v))
402
+
403
+ pnginfo.add_text(pnginfo_section_name, info)
404
+ else:
405
+ pnginfo = None
406
+
407
+ if save_to_dirs is None:
408
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
409
+
410
+ if save_to_dirs:
411
+ dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
412
+ path = os.path.join(path, dirname)
413
+
414
+ os.makedirs(path, exist_ok=True)
415
+
416
+ if forced_filename is None:
417
+ basecount = get_next_sequence_number(path, basename)
418
+ fullfn = "a.png"
419
+ fullfn_without_extension = "a"
420
+ for i in range(500):
421
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
422
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
423
+ fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
424
+ if not os.path.exists(fullfn):
425
+ break
426
+ else:
427
+ fullfn = os.path.join(path, f"{forced_filename}.{extension}")
428
+ fullfn_without_extension = os.path.join(path, forced_filename)
429
+
430
+ def exif_bytes():
431
+ return piexif.dump({
432
+ "Exif": {
433
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
434
+ },
435
+ })
436
+
437
+ if extension.lower() in ("jpg", "jpeg", "webp"):
438
+ image.save(fullfn, quality=opts.jpeg_quality)
439
+ if opts.enable_pnginfo and info is not None:
440
+ piexif.insert(exif_bytes(), fullfn)
441
+ else:
442
+ image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
443
+
444
+ target_side_length = 4000
445
+ oversize = image.width > target_side_length or image.height > target_side_length
446
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
447
+ ratio = image.width / image.height
448
+
449
+ if oversize and ratio > 1:
450
+ image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
451
+ elif oversize:
452
+ image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
453
+
454
+ image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
455
+ if opts.enable_pnginfo and info is not None:
456
+ piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
457
+
458
+ if opts.save_txt and info is not None:
459
+ txt_fullfn = f"{fullfn_without_extension}.txt"
460
+ with open(txt_fullfn, "w", encoding="utf8") as file:
461
+ file.write(info + "\n")
462
+ else:
463
+ txt_fullfn = None
464
+
465
+ return fullfn, txt_fullfn
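A small sketch of image_grid() above (illustrative only; assumes the webui modules are importable): with rows given explicitly it lays the images out row-major on a black canvas.

```python
from PIL import Image
from modules.images import image_grid

tiles = [Image.new("RGB", (64, 64), c) for c in ("red", "green", "blue", "white")]
grid = image_grid(tiles, rows=2)    # 2x2 layout -> a 128x128 image
grid.save("grid.png")
```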
modules/img2img.py ADDED
@@ -0,0 +1,137 @@
1
+ import math
2
+ import os
3
+ import sys
4
+ import traceback
5
+
6
+ import numpy as np
7
+ from PIL import Image, ImageOps, ImageChops
8
+
9
+ from modules import devices
10
+ from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
11
+ from modules.shared import opts, state
12
+ import modules.shared as shared
13
+ import modules.processing as processing
14
+ from modules.ui import plaintext_to_html
15
+ import modules.images as images
16
+ import modules.scripts
17
+
18
+
19
+ def process_batch(p, input_dir, output_dir, args):
20
+ processing.fix_seed(p)
21
+
22
+ images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
23
+
24
+ print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
25
+
26
+ save_normally = output_dir == ''
27
+
28
+ p.do_not_save_grid = True
29
+ p.do_not_save_samples = not save_normally
30
+
31
+ state.job_count = len(images) * p.n_iter
32
+
33
+ for i, image in enumerate(images):
34
+ state.job = f"{i+1} out of {len(images)}"
35
+ if state.skipped:
36
+ state.skipped = False
37
+
38
+ if state.interrupted:
39
+ break
40
+
41
+ img = Image.open(image)
42
+ p.init_images = [img] * p.batch_size
43
+
44
+ proc = modules.scripts.scripts_img2img.run(p, *args)
45
+ if proc is None:
46
+ proc = process_images(p)
47
+
48
+ for n, processed_image in enumerate(proc.images):
49
+ filename = os.path.basename(image)
50
+
51
+ if n > 0:
52
+ left, right = os.path.splitext(filename)
53
+ filename = f"{left}-{n}{right}"
54
+
55
+ if not save_normally:
56
+ processed_image.save(os.path.join(output_dir, filename))
57
+
58
+
59
+ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
60
+ is_inpaint = mode == 1
61
+ is_batch = mode == 2
62
+
63
+ if is_inpaint:
64
+ if mask_mode == 0:
65
+ image = init_img_with_mask['image']
66
+ mask = init_img_with_mask['mask']
67
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
68
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
69
+ image = image.convert('RGB')
70
+ else:
71
+ image = init_img_inpaint
72
+ mask = init_mask_inpaint
73
+ else:
74
+ image = init_img
75
+ mask = None
76
+
77
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
78
+
79
+ p = StableDiffusionProcessingImg2Img(
80
+ sd_model=shared.sd_model,
81
+ outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
82
+ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
83
+ prompt=prompt,
84
+ negative_prompt=negative_prompt,
85
+ styles=[prompt_style, prompt_style2],
86
+ seed=seed,
87
+ subseed=subseed,
88
+ subseed_strength=subseed_strength,
89
+ seed_resize_from_h=seed_resize_from_h,
90
+ seed_resize_from_w=seed_resize_from_w,
91
+ seed_enable_extras=seed_enable_extras,
92
+ sampler_index=sampler_index,
93
+ batch_size=batch_size,
94
+ n_iter=n_iter,
95
+ steps=steps,
96
+ cfg_scale=cfg_scale,
97
+ width=width,
98
+ height=height,
99
+ restore_faces=restore_faces,
100
+ tiling=tiling,
101
+ init_images=[image],
102
+ mask=mask,
103
+ mask_blur=mask_blur,
104
+ inpainting_fill=inpainting_fill,
105
+ resize_mode=resize_mode,
106
+ denoising_strength=denoising_strength,
107
+ inpaint_full_res=inpaint_full_res,
108
+ inpaint_full_res_padding=inpaint_full_res_padding,
109
+ inpainting_mask_invert=inpainting_mask_invert,
110
+ )
111
+
112
+ if shared.cmd_opts.enable_console_prompts:
113
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
114
+
115
+ p.extra_generation_params["Mask blur"] = mask_blur
116
+
117
+ if is_batch:
118
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
119
+
120
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
121
+
122
+ processed = Processed(p, [], p.seed, "")
123
+ else:
124
+ processed = modules.scripts.scripts_img2img.run(p, *args)
125
+ if processed is None:
126
+ processed = process_images(p)
127
+
128
+ shared.total_tqdm.clear()
129
+
130
+ generation_info_js = processed.js()
131
+ if opts.samples_log_stdout:
132
+ print(generation_info_js)
133
+
134
+ if opts.do_not_show_images:
135
+ processed.images = []
136
+
137
+ return processed.images, generation_info_js, plaintext_to_html(processed.info)
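In inpaint mode (mask_mode == 0) the code above unions the transparent pixels of the uploaded RGBA image with the user-drawn mask before processing. A small self-contained sketch of that merge; the file names in the usage comment are placeholders:

from PIL import Image, ImageChops, ImageOps

def combine_inpaint_mask(image_rgba, drawn_mask):
    """Union of 'erased to transparency' and 'painted by the user', as in img2img() above."""
    # any pixel with partial transparency becomes white in the alpha-derived mask
    alpha_mask = ImageOps.invert(image_rgba.split()[-1]).convert('L').point(
        lambda x: 255 if x > 0 else 0, mode='1')
    # lighter() keeps a pixel masked if either mask marks it
    return ImageChops.lighter(alpha_mask.convert('L'), drawn_mask.convert('L')).convert('L')

# hypothetical usage (both images must share the same size):
# image = Image.open("inpaint_source.png").convert("RGBA")
# mask = combine_inpaint_mask(image, Image.open("drawn_mask.png"))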
modules/interrogate.py ADDED
@@ -0,0 +1,171 @@
1
+ import contextlib
2
+ import os
3
+ import sys
4
+ import traceback
5
+ from collections import namedtuple
6
+ import re
7
+
8
+ import torch
9
+
10
+ from torchvision import transforms
11
+ from torchvision.transforms.functional import InterpolationMode
12
+
13
+ import modules.shared as shared
14
+ from modules import devices, paths, lowvram
15
+
16
+ blip_image_eval_size = 384
17
+ blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
18
+ clip_model_name = 'ViT-L/14'
19
+
20
+ Category = namedtuple("Category", ["name", "topn", "items"])
21
+
22
+ re_topn = re.compile(r"\.top(\d+)\.")
23
+
24
+
25
+ class InterrogateModels:
26
+ blip_model = None
27
+ clip_model = None
28
+ clip_preprocess = None
29
+ categories = None
30
+ dtype = None
31
+
32
+ def __init__(self, content_dir):
33
+ self.categories = []
34
+
35
+ if os.path.exists(content_dir):
36
+ for filename in os.listdir(content_dir):
37
+ m = re_topn.search(filename)
38
+ topn = 1 if m is None else int(m.group(1))
39
+
40
+ with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
41
+ lines = [x.strip() for x in file.readlines()]
42
+
43
+ self.categories.append(Category(name=filename, topn=topn, items=lines))
44
+
45
+ def load_blip_model(self):
46
+ import models.blip
47
+
48
+ blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
49
+ blip_model.eval()
50
+
51
+ return blip_model
52
+
53
+ def load_clip_model(self):
54
+ import clip
55
+
56
+ model, preprocess = clip.load(clip_model_name)
57
+ model.eval()
58
+ model = model.to(shared.device)
59
+
60
+ return model, preprocess
61
+
62
+ def load(self):
63
+ if self.blip_model is None:
64
+ self.blip_model = self.load_blip_model()
65
+ if not shared.cmd_opts.no_half:
66
+ self.blip_model = self.blip_model.half()
67
+
68
+ self.blip_model = self.blip_model.to(shared.device)
69
+
70
+ if self.clip_model is None:
71
+ self.clip_model, self.clip_preprocess = self.load_clip_model()
72
+ if not shared.cmd_opts.no_half:
73
+ self.clip_model = self.clip_model.half()
74
+
75
+ self.clip_model = self.clip_model.to(shared.device)
76
+
77
+ self.dtype = next(self.clip_model.parameters()).dtype
78
+
79
+ def send_clip_to_ram(self):
80
+ if not shared.opts.interrogate_keep_models_in_memory:
81
+ if self.clip_model is not None:
82
+ self.clip_model = self.clip_model.to(devices.cpu)
83
+
84
+ def send_blip_to_ram(self):
85
+ if not shared.opts.interrogate_keep_models_in_memory:
86
+ if self.blip_model is not None:
87
+ self.blip_model = self.blip_model.to(devices.cpu)
88
+
89
+ def unload(self):
90
+ self.send_clip_to_ram()
91
+ self.send_blip_to_ram()
92
+
93
+ devices.torch_gc()
94
+
95
+ def rank(self, image_features, text_array, top_count=1):
96
+ import clip
97
+
98
+ if shared.opts.interrogate_clip_dict_limit != 0:
99
+ text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
100
+
101
+ top_count = min(top_count, len(text_array))
102
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
103
+ text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
104
+ text_features /= text_features.norm(dim=-1, keepdim=True)
105
+
106
+ similarity = torch.zeros((1, len(text_array))).to(shared.device)
107
+ for i in range(image_features.shape[0]):
108
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
109
+ similarity /= image_features.shape[0]
110
+
111
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
112
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
113
+
114
+ def generate_caption(self, pil_image):
115
+ gpu_image = transforms.Compose([
116
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
117
+ transforms.ToTensor(),
118
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
119
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
120
+
121
+ with torch.no_grad():
122
+ caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
123
+
124
+ return caption[0]
125
+
126
+ def interrogate(self, pil_image, include_ranks=False):
127
+ res = None
128
+
129
+ try:
130
+
131
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
132
+ lowvram.send_everything_to_cpu()
133
+ devices.torch_gc()
134
+
135
+ self.load()
136
+
137
+ caption = self.generate_caption(pil_image)
138
+ self.send_blip_to_ram()
139
+ devices.torch_gc()
140
+
141
+ res = caption
142
+
143
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
144
+
145
+ precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
146
+ with torch.no_grad(), precision_scope("cuda"):
147
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
148
+
149
+ image_features /= image_features.norm(dim=-1, keepdim=True)
150
+
151
+ if shared.opts.interrogate_use_builtin_artists:
152
+ artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
153
+
154
+ res += ", " + artist[0]
155
+
156
+ for name, topn, items in self.categories:
157
+ matches = self.rank(image_features, items, top_count=topn)
158
+ for match, score in matches:
159
+ if include_ranks:
160
+ res += ", " + match
161
+ else:
162
+ res += f", ({match}:{score})"
163
+
164
+ except Exception:
165
+ print(f"Error interrogating", file=sys.stderr)
166
+ print(traceback.format_exc(), file=sys.stderr)
167
+ res += "<error>"
168
+
169
+ self.unload()
170
+
171
+ return res
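rank() above is zero-shot CLIP scoring: normalized image features are compared against tokenized candidate strings and the softmax of the scaled cosine similarities is ranked. A simplified single-image sketch of the same scoring with the openai clip package; the image path and candidate list are examples, and any CLIP variant can replace the ViT-L/14 name used by clip_model_name above:

import clip
import torch
from PIL import Image

@torch.no_grad()
def rank_texts(image_path, candidates, top_count=3, device="cpu"):
    """Return the best-matching candidate strings for an image, mirroring InterrogateModels.rank above."""
    model, preprocess = clip.load("ViT-L/14", device=device)
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)

    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)

    text_tokens = clip.tokenize(candidates, truncate=True).to(device)
    text_features = model.encode_text(text_tokens)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # scaled cosine similarity -> probability distribution over the candidates
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    probs, labels = similarity[0].topk(min(top_count, len(candidates)))
    return [(candidates[int(i)], float(p) * 100) for p, i in zip(probs, labels)]

# hypothetical usage:
# print(rank_texts("photo.png", ["a photo of a cat", "a photo of a dog", "a landscape"]))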
modules/ldsr_model.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ from basicsr.utils.download_util import load_file_from_url
6
+
7
+ from modules.upscaler import Upscaler, UpscalerData
8
+ from modules.ldsr_model_arch import LDSR
9
+ from modules import shared
10
+
11
+
12
+ class UpscalerLDSR(Upscaler):
13
+ def __init__(self, user_path):
14
+ self.name = "LDSR"
15
+ self.user_path = user_path
16
+ self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
17
+ self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
18
+ super().__init__()
19
+ scaler_data = UpscalerData("LDSR", None, self)
20
+ self.scalers = [scaler_data]
21
+
22
+ def load_model(self, path: str):
23
+ # Remove incorrect project.yaml file if too big
24
+ yaml_path = os.path.join(self.model_path, "project.yaml")
25
+ old_model_path = os.path.join(self.model_path, "model.pth")
26
+ new_model_path = os.path.join(self.model_path, "model.ckpt")
27
+ if os.path.exists(yaml_path):
28
+ statinfo = os.stat(yaml_path)
29
+ if statinfo.st_size >= 10485760:
30
+ print("Removing invalid LDSR YAML file.")
31
+ os.remove(yaml_path)
32
+ if os.path.exists(old_model_path):
33
+ print("Renaming model from model.pth to model.ckpt")
34
+ os.rename(old_model_path, new_model_path)
35
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
36
+ file_name="model.ckpt", progress=True)
37
+ yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
38
+ file_name="project.yaml", progress=True)
39
+
40
+ try:
41
+ return LDSR(model, yaml)
42
+
43
+ except Exception:
44
+ print("Error importing LDSR:", file=sys.stderr)
45
+ print(traceback.format_exc(), file=sys.stderr)
46
+ return None
47
+
48
+ def do_upscale(self, img, path):
49
+ ldsr = self.load_model(path)
50
+ if ldsr is None:
51
+ print("NO LDSR!")
52
+ return img
53
+ ddim_steps = shared.opts.ldsr_steps
54
+ return ldsr.super_resolution(img, ddim_steps, self.scale)
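load_model() above cleans up the LDSR model directory before downloading: a project.yaml over 10 MB is assumed to be a mis-saved file and removed, and a legacy model.pth is renamed to model.ckpt. A standalone sketch of just that sanity check; the directory path is a placeholder:

import os

MAX_YAML_BYTES = 10485760  # 10 MB; a real project.yaml is only a few kilobytes

def sanitize_ldsr_dir(model_dir):
    """Drop a bogus project.yaml and rename a legacy checkpoint, as in UpscalerLDSR.load_model above."""
    yaml_path = os.path.join(model_dir, "project.yaml")
    old_model_path = os.path.join(model_dir, "model.pth")
    new_model_path = os.path.join(model_dir, "model.ckpt")

    if os.path.exists(yaml_path) and os.stat(yaml_path).st_size >= MAX_YAML_BYTES:
        # anything this large is not a YAML config, so discard it and let a fresh copy download
        os.remove(yaml_path)

    if os.path.exists(old_model_path) and not os.path.exists(new_model_path):
        os.rename(old_model_path, new_model_path)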
modules/ldsr_model_arch.py ADDED
@@ -0,0 +1,222 @@
1
+ import gc
2
+ import time
3
+ import warnings
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision
8
+ from PIL import Image
9
+ from einops import rearrange, repeat
10
+ from omegaconf import OmegaConf
11
+
12
+ from ldm.models.diffusion.ddim import DDIMSampler
13
+ from ldm.util import instantiate_from_config, ismap
14
+
15
+ warnings.filterwarnings("ignore", category=UserWarning)
16
+
17
+
18
+ # Create LDSR Class
19
+ class LDSR:
20
+ def load_model_from_config(self, half_attention):
21
+ print(f"Loading model from {self.modelPath}")
22
+ pl_sd = torch.load(self.modelPath, map_location="cpu")
23
+ sd = pl_sd["state_dict"]
24
+ config = OmegaConf.load(self.yamlPath)
25
+ model = instantiate_from_config(config.model)
26
+ model.load_state_dict(sd, strict=False)
27
+ model.cuda()
28
+ if half_attention:
29
+ model = model.half()
30
+
31
+ model.eval()
32
+ return {"model": model}
33
+
34
+ def __init__(self, model_path, yaml_path):
35
+ self.modelPath = model_path
36
+ self.yamlPath = yaml_path
37
+
38
+ @staticmethod
39
+ def run(model, selected_path, custom_steps, eta):
40
+ example = get_cond(selected_path)
41
+
42
+ n_runs = 1
43
+ guider = None
44
+ ckwargs = None
45
+ ddim_use_x0_pred = False
46
+ temperature = 1.
47
+ eta = eta
48
+ custom_shape = None
49
+
50
+ height, width = example["image"].shape[1:3]
51
+ split_input = height >= 128 and width >= 128
52
+
53
+ if split_input:
54
+ ks = 128
55
+ stride = 64
56
+ vqf = 4 #
57
+ model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
58
+ "vqf": vqf,
59
+ "patch_distributed_vq": True,
60
+ "tie_braker": False,
61
+ "clip_max_weight": 0.5,
62
+ "clip_min_weight": 0.01,
63
+ "clip_max_tie_weight": 0.5,
64
+ "clip_min_tie_weight": 0.01}
65
+ else:
66
+ if hasattr(model, "split_input_params"):
67
+ delattr(model, "split_input_params")
68
+
69
+ x_t = None
70
+ logs = None
71
+ for n in range(n_runs):
72
+ if custom_shape is not None:
73
+ x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
74
+ x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
75
+
76
+ logs = make_convolutional_sample(example, model,
77
+ custom_steps=custom_steps,
78
+ eta=eta, quantize_x0=False,
79
+ custom_shape=custom_shape,
80
+ temperature=temperature, noise_dropout=0.,
81
+ corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
82
+ ddim_use_x0_pred=ddim_use_x0_pred
83
+ )
84
+ return logs
85
+
86
+ def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
87
+ model = self.load_model_from_config(half_attention)
88
+
89
+ # Run settings
90
+ diffusion_steps = int(steps)
91
+ eta = 1.0
92
+
93
+ down_sample_method = 'Lanczos'
94
+
95
+ gc.collect()
96
+ torch.cuda.empty_cache()
97
+
98
+ im_og = image
99
+ width_og, height_og = im_og.size
100
+ # If we can adjust the max upscale size, then the 4 below should be our variable
101
+ down_sample_rate = target_scale / 4
102
+ wd = width_og * down_sample_rate
103
+ hd = height_og * down_sample_rate
104
+ width_downsampled_pre = int(wd)
105
+ height_downsampled_pre = int(hd)
106
+
107
+ if down_sample_rate != 1:
108
+ print(
109
+ f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
110
+ im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
111
+ else:
112
+ print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
113
+ logs = self.run(model["model"], im_og, diffusion_steps, eta)
114
+
115
+ sample = logs["sample"]
116
+ sample = sample.detach().cpu()
117
+ sample = torch.clamp(sample, -1., 1.)
118
+ sample = (sample + 1.) / 2. * 255
119
+ sample = sample.numpy().astype(np.uint8)
120
+ sample = np.transpose(sample, (0, 2, 3, 1))
121
+ a = Image.fromarray(sample[0])
122
+
123
+ del model
124
+ gc.collect()
125
+ torch.cuda.empty_cache()
126
+ return a
127
+
128
+
129
+ def get_cond(selected_path):
130
+ example = dict()
131
+ up_f = 4
132
+ c = selected_path.convert('RGB')
133
+ c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
134
+ c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
135
+ antialias=True)
136
+ c_up = rearrange(c_up, '1 c h w -> 1 h w c')
137
+ c = rearrange(c, '1 c h w -> 1 h w c')
138
+ c = 2. * c - 1.
139
+
140
+ c = c.to(torch.device("cuda"))
141
+ example["LR_image"] = c
142
+ example["image"] = c_up
143
+
144
+ return example
145
+
146
+
147
+ @torch.no_grad()
148
+ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
149
+ mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
150
+ corrector_kwargs=None, x_t=None
151
+ ):
152
+ ddim = DDIMSampler(model)
153
+ bs = shape[0]
154
+ shape = shape[1:]
155
+ print(f"Sampling with eta = {eta}; steps: {steps}")
156
+ samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
157
+ normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
158
+ mask=mask, x0=x0, temperature=temperature, verbose=False,
159
+ score_corrector=score_corrector,
160
+ corrector_kwargs=corrector_kwargs, x_t=x_t)
161
+
162
+ return samples, intermediates
163
+
164
+
165
+ @torch.no_grad()
166
+ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
167
+ corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
168
+ log = dict()
169
+
170
+ z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
171
+ return_first_stage_outputs=True,
172
+ force_c_encode=not (hasattr(model, 'split_input_params')
173
+ and model.cond_stage_key == 'coordinates_bbox'),
174
+ return_original_cond=True)
175
+
176
+ if custom_shape is not None:
177
+ z = torch.randn(custom_shape)
178
+ print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
179
+
180
+ z0 = None
181
+
182
+ log["input"] = x
183
+ log["reconstruction"] = xrec
184
+
185
+ if ismap(xc):
186
+ log["original_conditioning"] = model.to_rgb(xc)
187
+ if hasattr(model, 'cond_stage_key'):
188
+ log[model.cond_stage_key] = model.to_rgb(xc)
189
+
190
+ else:
191
+ log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
192
+ if model.cond_stage_model:
193
+ log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
194
+ if model.cond_stage_key == 'class_label':
195
+ log[model.cond_stage_key] = xc[model.cond_stage_key]
196
+
197
+ with model.ema_scope("Plotting"):
198
+ t0 = time.time()
199
+
200
+ sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
201
+ eta=eta,
202
+ quantize_x0=quantize_x0, mask=None, x0=z0,
203
+ temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
204
+ x_t=x_T)
205
+ t1 = time.time()
206
+
207
+ if ddim_use_x0_pred:
208
+ sample = intermediates['pred_x0'][-1]
209
+
210
+ x_sample = model.decode_first_stage(sample)
211
+
212
+ try:
213
+ x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
214
+ log["sample_noquant"] = x_sample_noquant
215
+ log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
216
+ except:
217
+ pass
218
+
219
+ log["sample"] = x_sample
220
+ log["time"] = t1 - t0
221
+
222
+ return log
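super_resolution() above turns the decoded sample back into an image by clamping to [-1, 1], rescaling to [0, 255] and reordering NCHW to NHWC for PIL. A small sketch of just that conversion; the random tensor in the usage comment stands in for a decoded sample:

import numpy as np
import torch
from PIL import Image

def tensor_to_pil(sample):
    """Convert a (1, C, H, W) tensor in [-1, 1] to a PIL image, as in LDSR.super_resolution above."""
    sample = sample.detach().cpu()
    sample = torch.clamp(sample, -1.0, 1.0)
    sample = (sample + 1.0) / 2.0 * 255          # [-1, 1] -> [0, 255]
    sample = sample.numpy().astype(np.uint8)
    sample = np.transpose(sample, (0, 2, 3, 1))  # NCHW -> NHWC
    return Image.fromarray(sample[0])

# hypothetical usage with a dummy "decoded sample":
# img = tensor_to_pil(torch.rand(1, 3, 64, 64) * 2 - 1)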
modules/lowvram.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ from modules.devices import get_optimal_device
3
+
4
+ module_in_gpu = None
5
+ cpu = torch.device("cpu")
6
+ device = gpu = get_optimal_device()
7
+
8
+
9
+ def send_everything_to_cpu():
10
+ global module_in_gpu
11
+
12
+ if module_in_gpu is not None:
13
+ module_in_gpu.to(cpu)
14
+
15
+ module_in_gpu = None
16
+
17
+
18
+ def setup_for_low_vram(sd_model, use_medvram):
19
+ parents = {}
20
+
21
+ def send_me_to_gpu(module, _):
22
+ """send this module to GPU; send whatever tracked module was previously in GPU to CPU;
23
+ we add this as forward_pre_hook to a lot of modules, and this way all but one of them will
24
+ stay in CPU
25
+ """
26
+ global module_in_gpu
27
+
28
+ module = parents.get(module, module)
29
+
30
+ if module_in_gpu == module:
31
+ return
32
+
33
+ if module_in_gpu is not None:
34
+ module_in_gpu.to(cpu)
35
+
36
+ module.to(gpu)
37
+ module_in_gpu = module
38
+
39
+ # see below for register_forward_pre_hook;
40
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
41
+ # useless here, and we just replace those methods
42
+ def first_stage_model_encode_wrap(self, encoder, x):
43
+ send_me_to_gpu(self, None)
44
+ return encoder(x)
45
+
46
+ def first_stage_model_decode_wrap(self, decoder, z):
47
+ send_me_to_gpu(self, None)
48
+ return decoder(z)
49
+
50
+ # remove the three big modules, cond, first_stage, and unet, from the model and then
51
+ # send the model to GPU. Then put the modules back. The modules will remain in CPU.
52
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
53
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
54
+ sd_model.to(device)
55
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
56
+
57
+ # register hooks for the first two models
58
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
59
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
60
+ sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
61
+ sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
62
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
63
+
64
+ if use_medvram:
65
+ sd_model.model.register_forward_pre_hook(send_me_to_gpu)
66
+ else:
67
+ diff_model = sd_model.model.diffusion_model
68
+
69
+ # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
70
+ # so that only one of them is in GPU at a time
71
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
72
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
73
+ sd_model.model.to(device)
74
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
75
+
76
+ # install hooks for bits of third model
77
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
78
+ for block in diff_model.input_blocks:
79
+ block.register_forward_pre_hook(send_me_to_gpu)
80
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
81
+ for block in diff_model.output_blocks:
82
+ block.register_forward_pre_hook(send_me_to_gpu)
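The low-VRAM strategy above keeps only one large sub-module on the GPU at a time: a forward_pre_hook moves the module about to run onto the GPU and evicts whichever tracked module was resident before it. A toy sketch of the same mechanism, with two small linear layers standing in for the Stable Diffusion components:

import torch
import torch.nn as nn

cpu = torch.device("cpu")
gpu = torch.device("cuda") if torch.cuda.is_available() else cpu

module_in_gpu = None

def send_me_to_gpu(module, _inputs):
    """forward_pre_hook: move this module to the GPU, evicting the previously resident one."""
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module.to(gpu)
    module_in_gpu = module

a = nn.Linear(8, 8)
b = nn.Linear(8, 8)
a.register_forward_pre_hook(send_me_to_gpu)
b.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 8).to(gpu)
y = a(x)   # hook moves `a` onto the GPU before its forward runs
z = b(y)   # hook moves `b` in and sends `a` back to the CPU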
modules/masking.py ADDED
@@ -0,0 +1,99 @@
1
+ from PIL import Image, ImageFilter, ImageOps
2
+
3
+
4
+ def get_crop_region(mask, pad=0):
5
+ """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6
+ For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
7
+
8
+ h, w = mask.shape
9
+
10
+ crop_left = 0
11
+ for i in range(w):
12
+ if not (mask[:, i] == 0).all():
13
+ break
14
+ crop_left += 1
15
+
16
+ crop_right = 0
17
+ for i in reversed(range(w)):
18
+ if not (mask[:, i] == 0).all():
19
+ break
20
+ crop_right += 1
21
+
22
+ crop_top = 0
23
+ for i in range(h):
24
+ if not (mask[i] == 0).all():
25
+ break
26
+ crop_top += 1
27
+
28
+ crop_bottom = 0
29
+ for i in reversed(range(h)):
30
+ if not (mask[i] == 0).all():
31
+ break
32
+ crop_bottom += 1
33
+
34
+ return (
35
+ int(max(crop_left-pad, 0)),
36
+ int(max(crop_top-pad, 0)),
37
+ int(min(w - crop_right + pad, w)),
38
+ int(min(h - crop_bottom + pad, h))
39
+ )
40
+
41
+
42
+ def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
43
+ """expands the crop region returned by get_crop_region() to match the aspect ratio of the image the region will be processed in; returns the expanded region
44
+ for example, if the user drew a mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
45
+
46
+ x1, y1, x2, y2 = crop_region
47
+
48
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
49
+ ratio_processing = processing_width / processing_height
50
+
51
+ if ratio_crop_region > ratio_processing:
52
+ desired_height = (x2 - x1) * ratio_processing
53
+ desired_height_diff = int(desired_height - (y2-y1))
54
+ y1 -= desired_height_diff//2
55
+ y2 += desired_height_diff - desired_height_diff//2
56
+ if y2 >= image_height:
57
+ diff = y2 - image_height
58
+ y2 -= diff
59
+ y1 -= diff
60
+ if y1 < 0:
61
+ y2 -= y1
62
+ y1 -= y1
63
+ if y2 >= image_height:
64
+ y2 = image_height
65
+ else:
66
+ desired_width = (y2 - y1) * ratio_processing
67
+ desired_width_diff = int(desired_width - (x2-x1))
68
+ x1 -= desired_width_diff//2
69
+ x2 += desired_width_diff - desired_width_diff//2
70
+ if x2 >= image_width:
71
+ diff = x2 - image_width
72
+ x2 -= diff
73
+ x1 -= diff
74
+ if x1 < 0:
75
+ x2 -= x1
76
+ x1 -= x1
77
+ if x2 >= image_width:
78
+ x2 = image_width
79
+
80
+ return x1, y1, x2, y2
81
+
82
+
83
+ def fill(image, mask):
84
+ """fills masked regions with colors from image using blur. Not extremely effective."""
85
+
86
+ image_mod = Image.new('RGBA', (image.width, image.height))
87
+
88
+ image_masked = Image.new('RGBa', (image.width, image.height))
89
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
90
+
91
+ image_masked = image_masked.convert('RGBa')
92
+
93
+ for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
94
+ blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
95
+ for _ in range(repeats):
96
+ image_mod.alpha_composite(blurred)
97
+
98
+ return image_mod.convert("RGB")
99
+
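get_crop_region() above scans inward from each edge of the mask until it hits a non-zero column or row, then pads the result; expand_crop_region() then grows that box to the aspect ratio used for processing. A quick check on a synthetic numpy mask, assuming the script runs from the repository root so that modules.masking is importable; the 64x64 size and painted rectangle are arbitrary:

import numpy as np
from modules import masking

mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:20, 30:50] = 255  # a painted rectangle

# scan in from each edge, then pad by 4 px
crop = masking.get_crop_region(mask, pad=4)
print(crop)  # -> (26, 6, 54, 24)

# grow the crop to the 1:1 aspect ratio used for 512x512 processing
print(masking.expand_crop_region(crop, 512, 512, 64, 64))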
modules/memmon.py ADDED
@@ -0,0 +1,85 @@
1
+ import threading
2
+ import time
3
+ from collections import defaultdict
4
+
5
+ import torch
6
+
7
+
8
+ class MemUsageMonitor(threading.Thread):
9
+ run_flag = None
10
+ device = None
11
+ disabled = False
12
+ opts = None
13
+ data = None
14
+
15
+ def __init__(self, name, device, opts):
16
+ threading.Thread.__init__(self)
17
+ self.name = name
18
+ self.device = device
19
+ self.opts = opts
20
+
21
+ self.daemon = True
22
+ self.run_flag = threading.Event()
23
+ self.data = defaultdict(int)
24
+
25
+ try:
26
+ torch.cuda.mem_get_info()
27
+ torch.cuda.memory_stats(self.device)
28
+ except Exception as e: # AMD or whatever
29
+ print(f"Warning: caught exception '{e}', memory monitor disabled")
30
+ self.disabled = True
31
+
32
+ def run(self):
33
+ if self.disabled:
34
+ return
35
+
36
+ while True:
37
+ self.run_flag.wait()
38
+
39
+ torch.cuda.reset_peak_memory_stats()
40
+ self.data.clear()
41
+
42
+ if self.opts.memmon_poll_rate <= 0:
43
+ self.run_flag.clear()
44
+ continue
45
+
46
+ self.data["min_free"] = torch.cuda.mem_get_info()[0]
47
+
48
+ while self.run_flag.is_set():
49
+ free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
50
+ self.data["min_free"] = min(self.data["min_free"], free)
51
+
52
+ time.sleep(1 / self.opts.memmon_poll_rate)
53
+
54
+ def dump_debug(self):
55
+ print(self, 'recorded data:')
56
+ for k, v in self.read().items():
57
+ print(k, -(v // -(1024 ** 2)))
58
+
59
+ print(self, 'raw torch memory stats:')
60
+ tm = torch.cuda.memory_stats(self.device)
61
+ for k, v in tm.items():
62
+ if 'bytes' not in k:
63
+ continue
64
+ print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
65
+
66
+ print(torch.cuda.memory_summary())
67
+
68
+ def monitor(self):
69
+ self.run_flag.set()
70
+
71
+ def read(self):
72
+ if not self.disabled:
73
+ free, total = torch.cuda.mem_get_info()
74
+ self.data["total"] = total
75
+
76
+ torch_stats = torch.cuda.memory_stats(self.device)
77
+ self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
78
+ self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
79
+ self.data["system_peak"] = total - self.data["min_free"]
80
+
81
+ return self.data
82
+
83
+ def stop(self):
84
+ self.run_flag.clear()
85
+ return self.read()
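MemUsageMonitor above samples torch.cuda.mem_get_info() on a daemon thread between monitor() and stop(), and read() merges in the peak statistics from torch. A hedged usage sketch, assuming it runs from the repository root; the SimpleNamespace only supplies the memmon_poll_rate field the class actually reads:

from types import SimpleNamespace

import torch

from modules.memmon import MemUsageMonitor

opts = SimpleNamespace(memmon_poll_rate=8)  # samples per second

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
mon = MemUsageMonitor("demo", device, opts)
mon.start()         # the thread idles until monitor() sets the run flag

mon.monitor()
# ... run a generation or any other CUDA workload here ...
stats = mon.stop()  # clears the run flag and returns read()

print({k: v // (1024 ** 2) for k, v in stats.items()})  # values in MiB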
modules/modelloader.py ADDED
@@ -0,0 +1,153 @@
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import importlib
5
+ from urllib.parse import urlparse
6
+
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from modules import shared
9
+ from modules.upscaler import Upscaler
10
+ from modules.paths import script_path, models_path
11
+
12
+
13
+ def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
14
+ """
15
+ A one-and-done loader that tries to find the desired models in the specified directories.
16
+
17
+ @param download_name: Specify to download from model_url immediately.
18
+ @param model_url: If no other models are found, this will be downloaded on upscale.
19
+ @param model_path: The location to store/find models in.
20
+ @param command_path: A command-line argument to search for models in first.
21
+ @param ext_filter: An optional list of filename extensions to filter by
22
+ @return: A list of paths containing the desired model(s)
23
+ """
24
+ output = []
25
+
26
+ if ext_filter is None:
27
+ ext_filter = []
28
+
29
+ try:
30
+ places = []
31
+
32
+ if command_path is not None and command_path != model_path:
33
+ pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
34
+ if os.path.exists(pretrained_path):
35
+ print(f"Appending path: {pretrained_path}")
36
+ places.append(pretrained_path)
37
+ elif os.path.exists(command_path):
38
+ places.append(command_path)
39
+
40
+ places.append(model_path)
41
+
42
+ for place in places:
43
+ if os.path.exists(place):
44
+ for file in glob.iglob(place + '**/**', recursive=True):
45
+ full_path = file
46
+ if os.path.isdir(full_path):
47
+ continue
48
+ if len(ext_filter) != 0:
49
+ model_name, extension = os.path.splitext(file)
50
+ if extension not in ext_filter:
51
+ continue
52
+ if file not in output:
53
+ output.append(full_path)
54
+
55
+ if model_url is not None and len(output) == 0:
56
+ if download_name is not None:
57
+ dl = load_file_from_url(model_url, model_path, True, download_name)
58
+ output.append(dl)
59
+ else:
60
+ output.append(model_url)
61
+
62
+ except Exception:
63
+ pass
64
+
65
+ return output
66
+
67
+
68
+ def friendly_name(file: str):
69
+ if "http" in file:
70
+ file = urlparse(file).path
71
+
72
+ file = os.path.basename(file)
73
+ model_name, extension = os.path.splitext(file)
74
+ return model_name
75
+
76
+
77
+ def cleanup_models():
78
+ # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
79
+ # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
80
+ # somehow auto-register and just do these things...
81
+ root_path = script_path
82
+ src_path = models_path
83
+ dest_path = os.path.join(models_path, "Stable-diffusion")
84
+ move_files(src_path, dest_path, ".ckpt")
85
+ src_path = os.path.join(root_path, "ESRGAN")
86
+ dest_path = os.path.join(models_path, "ESRGAN")
87
+ move_files(src_path, dest_path)
88
+ src_path = os.path.join(root_path, "gfpgan")
89
+ dest_path = os.path.join(models_path, "GFPGAN")
90
+ move_files(src_path, dest_path)
91
+ src_path = os.path.join(root_path, "SwinIR")
92
+ dest_path = os.path.join(models_path, "SwinIR")
93
+ move_files(src_path, dest_path)
94
+ src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
95
+ dest_path = os.path.join(models_path, "LDSR")
96
+ move_files(src_path, dest_path)
97
+
98
+
99
+ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
100
+ try:
101
+ if not os.path.exists(dest_path):
102
+ os.makedirs(dest_path)
103
+ if os.path.exists(src_path):
104
+ for file in os.listdir(src_path):
105
+ fullpath = os.path.join(src_path, file)
106
+ if os.path.isfile(fullpath):
107
+ if ext_filter is not None:
108
+ if ext_filter not in file:
109
+ continue
110
+ print(f"Moving {file} from {src_path} to {dest_path}.")
111
+ try:
112
+ shutil.move(fullpath, dest_path)
113
+ except:
114
+ pass
115
+ if len(os.listdir(src_path)) == 0:
116
+ print(f"Removing empty folder: {src_path}")
117
+ shutil.rmtree(src_path, True)
118
+ except:
119
+ pass
120
+
121
+
122
+ def load_upscalers():
123
+ sd = shared.script_path
124
+ # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
125
+ # so we'll try to import any _model.py files before looking in __subclasses__
126
+ modules_dir = os.path.join(sd, "modules")
127
+ for file in os.listdir(modules_dir):
128
+ if "_model.py" in file:
129
+ model_name = file.replace("_model.py", "")
130
+ full_model = f"modules.{model_name}_model"
131
+ try:
132
+ importlib.import_module(full_model)
133
+ except:
134
+ pass
135
+ datas = []
136
+ c_o = vars(shared.cmd_opts)
137
+ for cls in Upscaler.__subclasses__():
138
+ name = cls.__name__
139
+ module_name = cls.__module__
140
+ module = importlib.import_module(module_name)
141
+ class_ = getattr(module, name)
142
+ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
143
+ opt_string = None
144
+ try:
145
+ if cmd_name in c_o:
146
+ opt_string = c_o[cmd_name]
147
+ except:
148
+ pass
149
+ scaler = class_(opt_string)
150
+ for child in scaler.scalers:
151
+ datas.append(child)
152
+
153
+ shared.sd_upscalers = datas
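load_models() above gathers checkpoint files from an optional command-line directory plus the default model folder, filters them by extension, and can fall back to downloading model_url. A hedged usage sketch, run from the repository root; the ESRGAN folder and .pth filter are illustrative, and no download URL is supplied:

import os

from modules import modelloader
from modules.paths import models_path

found = modelloader.load_models(
    model_path=os.path.join(models_path, "ESRGAN"),
    model_url=None,      # no fallback download in this sketch
    command_path=None,   # or an --esrgan-models-path style override
    ext_filter=[".pth"],
    download_name=None,
)

for path in found:
    print(modelloader.friendly_name(path), "->", path)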
modules/ngrok.py ADDED
@@ -0,0 +1,15 @@
1
+ from pyngrok import ngrok, conf, exception
2
+
3
+
4
+ def connect(token, port):
5
+ if token is None:
6
+ token = 'None'
7
+ conf.get_default().auth_token = token
8
+ try:
9
+ public_url = ngrok.connect(port).public_url
10
+ except exception.PyngrokNgrokError:
11
+ print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
12
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
13
+ else:
14
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
15
+ 'You can use this link after the launch is complete.')
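connect() above just stores the authtoken with pyngrok and opens a tunnel to the given local port, printing either the public URL or a hint that the token was rejected. A minimal usage sketch; 7860 is the usual Gradio port, and the token placeholder must be replaced with a real one from the ngrok dashboard (passing None makes the module store the literal string 'None', which ngrok will normally reject):

from modules import ngrok

ngrok.connect("YOUR_NGROK_AUTHTOKEN", 7860)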
modules/paths.py ADDED
@@ -0,0 +1,40 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import modules.safe
5
+
6
+ script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
7
+ models_path = os.path.join(script_path, "models")
8
+ sys.path.insert(0, script_path)
9
+
10
+ # search for directory of stable diffusion in following places
11
+ sd_path = None
12
+ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
13
+ for possible_sd_path in possible_sd_paths:
14
+ if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
15
+ sd_path = os.path.abspath(possible_sd_path)
16
+ break
17
+
18
+ assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
19
+
20
+ path_dirs = [
21
+ (sd_path, 'ldm', 'Stable Diffusion', []),
22
+ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
23
+ (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
24
+ (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
25
+ (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
26
+ ]
27
+
28
+ paths = {}
29
+
30
+ for d, must_exist, what, options in path_dirs:
31
+ must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
32
+ if not os.path.exists(must_exist_path):
33
+ print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
34
+ else:
35
+ d = os.path.abspath(d)
36
+ if "atstart" in options:
37
+ sys.path.insert(0, d)
38
+ else:
39
+ sys.path.append(d)
40
+ paths[what] = d
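paths.py registers each repository it finds on sys.path and records it in the paths dict, which other modules use to locate bundled files (interrogate.py above does this for BLIP's med_config.json). A small sketch of consuming that dict, assuming it runs from the repository root:

import os

from modules import paths

# every repository found at startup is keyed by its display name
for name, directory in paths.paths.items():
    print(f"{name:20} -> {directory}")

# e.g. the BLIP config used by modules/interrogate.py, if that repo was found
blip_dir = paths.paths.get("BLIP")
if blip_dir is not None:
    med_config = os.path.join(blip_dir, "configs", "med_config.json")
    print("BLIP med_config:", med_config, os.path.exists(med_config))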
modules/processing.py ADDED
@@ -0,0 +1,721 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image, ImageFilter, ImageOps
9
+ import random
10
+ import cv2
11
+ from skimage import exposure
12
+
13
+ import modules.sd_hijack
14
+ from modules import devices, prompt_parser, masking, sd_samplers, lowvram
15
+ from modules.sd_hijack import model_hijack
16
+ from modules.shared import opts, cmd_opts, state
17
+ import modules.shared as shared
18
+ import modules.face_restoration
19
+ import modules.images as images
20
+ import modules.styles
21
+ import logging
22
+
23
+
24
+ # some of those options should not be changed at all because they would break the model, so I removed them from options.
25
+ opt_C = 4
26
+ opt_f = 8
27
+
28
+
29
+ def setup_color_correction(image):
30
+ logging.info("Calibrating color correction.")
31
+ correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
32
+ return correction_target
33
+
34
+
35
+ def apply_color_correction(correction, image):
36
+ logging.info("Applying color correction.")
37
+ image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
38
+ cv2.cvtColor(
39
+ np.asarray(image),
40
+ cv2.COLOR_RGB2LAB
41
+ ),
42
+ correction,
43
+ channel_axis=2
44
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
45
+
46
+ return image
47
+
48
+
49
+ def get_correct_sampler(p):
50
+ if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
51
+ return sd_samplers.samplers
52
+ elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
53
+ return sd_samplers.samplers_for_img2img
54
+
55
+ class StableDiffusionProcessing:
56
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
57
+ self.sd_model = sd_model
58
+ self.outpath_samples: str = outpath_samples
59
+ self.outpath_grids: str = outpath_grids
60
+ self.prompt: str = prompt
61
+ self.prompt_for_display: str = None
62
+ self.negative_prompt: str = (negative_prompt or "")
63
+ self.styles: list = styles or []
64
+ self.seed: int = seed
65
+ self.subseed: int = subseed
66
+ self.subseed_strength: float = subseed_strength
67
+ self.seed_resize_from_h: int = seed_resize_from_h
68
+ self.seed_resize_from_w: int = seed_resize_from_w
69
+ self.sampler_index: int = sampler_index
70
+ self.batch_size: int = batch_size
71
+ self.n_iter: int = n_iter
72
+ self.steps: int = steps
73
+ self.cfg_scale: float = cfg_scale
74
+ self.width: int = width
75
+ self.height: int = height
76
+ self.restore_faces: bool = restore_faces
77
+ self.tiling: bool = tiling
78
+ self.do_not_save_samples: bool = do_not_save_samples
79
+ self.do_not_save_grid: bool = do_not_save_grid
80
+ self.extra_generation_params: dict = extra_generation_params or {}
81
+ self.overlay_images = overlay_images
82
+ self.eta = eta
83
+ self.paste_to = None
84
+ self.color_corrections = None
85
+ self.denoising_strength: float = 0
86
+ self.sampler_noise_scheduler_override = None
87
+ self.ddim_discretize = opts.ddim_discretize
88
+ self.s_churn = opts.s_churn
89
+ self.s_tmin = opts.s_tmin
90
+ self.s_tmax = float('inf') # not representable as a standard ui option
91
+ self.s_noise = opts.s_noise
92
+
93
+ if not seed_enable_extras:
94
+ self.subseed = -1
95
+ self.subseed_strength = 0
96
+ self.seed_resize_from_h = 0
97
+ self.seed_resize_from_w = 0
98
+
99
+ def init(self, all_prompts, all_seeds, all_subseeds):
100
+ pass
101
+
102
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
103
+ raise NotImplementedError()
104
+
105
+
106
+ class Processed:
107
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
108
+ self.images = images_list
109
+ self.prompt = p.prompt
110
+ self.negative_prompt = p.negative_prompt
111
+ self.seed = seed
112
+ self.subseed = subseed
113
+ self.subseed_strength = p.subseed_strength
114
+ self.info = info
115
+ self.width = p.width
116
+ self.height = p.height
117
+ self.sampler_index = p.sampler_index
118
+ self.sampler = sd_samplers.samplers[p.sampler_index].name
119
+ self.cfg_scale = p.cfg_scale
120
+ self.steps = p.steps
121
+ self.batch_size = p.batch_size
122
+ self.restore_faces = p.restore_faces
123
+ self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
124
+ self.sd_model_hash = shared.sd_model.sd_model_hash
125
+ self.seed_resize_from_w = p.seed_resize_from_w
126
+ self.seed_resize_from_h = p.seed_resize_from_h
127
+ self.denoising_strength = getattr(p, 'denoising_strength', None)
128
+ self.extra_generation_params = p.extra_generation_params
129
+ self.index_of_first_image = index_of_first_image
130
+ self.styles = p.styles
131
+ self.job_timestamp = state.job_timestamp
132
+ self.clip_skip = opts.CLIP_stop_at_last_layers
133
+
134
+ self.eta = p.eta
135
+ self.ddim_discretize = p.ddim_discretize
136
+ self.s_churn = p.s_churn
137
+ self.s_tmin = p.s_tmin
138
+ self.s_tmax = p.s_tmax
139
+ self.s_noise = p.s_noise
140
+ self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
141
+ self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
142
+ self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
143
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
144
+ self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
145
+
146
+ self.all_prompts = all_prompts or [self.prompt]
147
+ self.all_seeds = all_seeds or [self.seed]
148
+ self.all_subseeds = all_subseeds or [self.subseed]
149
+ self.infotexts = infotexts or [info]
150
+
151
+ def js(self):
152
+ obj = {
153
+ "prompt": self.prompt,
154
+ "all_prompts": self.all_prompts,
155
+ "negative_prompt": self.negative_prompt,
156
+ "seed": self.seed,
157
+ "all_seeds": self.all_seeds,
158
+ "subseed": self.subseed,
159
+ "all_subseeds": self.all_subseeds,
160
+ "subseed_strength": self.subseed_strength,
161
+ "width": self.width,
162
+ "height": self.height,
163
+ "sampler_index": self.sampler_index,
164
+ "sampler": self.sampler,
165
+ "cfg_scale": self.cfg_scale,
166
+ "steps": self.steps,
167
+ "batch_size": self.batch_size,
168
+ "restore_faces": self.restore_faces,
169
+ "face_restoration_model": self.face_restoration_model,
170
+ "sd_model_hash": self.sd_model_hash,
171
+ "seed_resize_from_w": self.seed_resize_from_w,
172
+ "seed_resize_from_h": self.seed_resize_from_h,
173
+ "denoising_strength": self.denoising_strength,
174
+ "extra_generation_params": self.extra_generation_params,
175
+ "index_of_first_image": self.index_of_first_image,
176
+ "infotexts": self.infotexts,
177
+ "styles": self.styles,
178
+ "job_timestamp": self.job_timestamp,
179
+ "clip_skip": self.clip_skip,
180
+ }
181
+
182
+ return json.dumps(obj)
183
+
184
+ def infotext(self, p: StableDiffusionProcessing, index):
185
+ return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
186
+
187
+
188
+ # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
189
+ def slerp(val, low, high):
190
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
191
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
192
+ dot = (low_norm*high_norm).sum(1)
193
+
194
+ if dot.mean() > 0.9995:
195
+ return low * val + high * (1 - val)
196
+
197
+ omega = torch.acos(dot)
198
+ so = torch.sin(omega)
199
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
200
+ return res
201
+
202
+
203
+ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
204
+ xs = []
205
+
206
+ # if we have multiple seeds, this means we are working with batch size>1; this then
207
+ # enables the generation of additional tensors with noise that the sampler will use during its processing.
208
+ # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
209
+ # produce the same images as with two batches [100], [101].
210
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
211
+ sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
212
+ else:
213
+ sampler_noises = None
214
+
215
+ for i, seed in enumerate(seeds):
216
+ noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
217
+
218
+ subnoise = None
219
+ if subseeds is not None:
220
+ subseed = 0 if i >= len(subseeds) else subseeds[i]
221
+
222
+ subnoise = devices.randn(subseed, noise_shape)
223
+
224
+ # randn results depend on device; gpu and cpu get different results for same seed;
225
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
226
+ # but the original script had it like this, so I do not dare change it for now because
227
+ # it will break everyone's seeds.
228
+ noise = devices.randn(seed, noise_shape)
229
+
230
+ if subnoise is not None:
231
+ noise = slerp(subseed_strength, noise, subnoise)
232
+
233
+ if noise_shape != shape:
234
+ x = devices.randn(seed, shape)
235
+ dx = (shape[2] - noise_shape[2]) // 2
236
+ dy = (shape[1] - noise_shape[1]) // 2
237
+ w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
238
+ h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
239
+ tx = 0 if dx < 0 else dx
240
+ ty = 0 if dy < 0 else dy
241
+ dx = max(-dx, 0)
242
+ dy = max(-dy, 0)
243
+
244
+ x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
245
+ noise = x
246
+
247
+ if sampler_noises is not None:
248
+ cnt = p.sampler.number_of_needed_noises(p)
249
+
250
+ if opts.eta_noise_seed_delta > 0:
251
+ torch.manual_seed(seed + opts.eta_noise_seed_delta)
252
+
253
+ for j in range(cnt):
254
+ sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
255
+
256
+ xs.append(noise)
257
+
258
+ if sampler_noises is not None:
259
+ p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
260
+
261
+ x = torch.stack(xs).to(shared.device)
262
+ return x
263
+
264
+
265
+ def decode_first_stage(model, x):
266
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
267
+ x = model.decode_first_stage(x)
268
+
269
+ return x
270
+
271
+
272
+ def get_fixed_seed(seed):
273
+ if seed is None or seed == '' or seed == -1:
274
+ return int(random.randrange(4294967294))
275
+
276
+ return seed
277
+
278
+
279
+ def fix_seed(p):
280
+ p.seed = get_fixed_seed(p.seed)
281
+ p.subseed = get_fixed_seed(p.subseed)
282
+
283
+
284
+ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
285
+ index = position_in_batch + iteration * p.batch_size
286
+
287
+ clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
288
+
289
+ generation_params = {
290
+ "Steps": p.steps,
291
+ "Sampler": get_correct_sampler(p)[p.sampler_index].name,
292
+ "CFG scale": p.cfg_scale,
293
+ "Seed": all_seeds[index],
294
+ "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
295
+ "Size": f"{p.width}x{p.height}",
296
+ "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
297
+ "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
298
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
299
+ "Batch size": (None if p.batch_size < 2 else p.batch_size),
300
+ "Batch pos": (None if p.batch_size < 2 else position_in_batch),
301
+ "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
302
+ "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
303
+ "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
304
+ "Denoising strength": getattr(p, 'denoising_strength', None),
305
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
306
+ "Clip skip": None if clip_skip <= 1 else clip_skip,
307
+ "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
308
+ }
309
+
310
+ generation_params.update(p.extra_generation_params)
311
+
312
+ generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
313
+
314
+ negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
315
+
316
+ return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
317
+
318
+
319
+ def process_images(p: StableDiffusionProcessing) -> Processed:
320
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
321
+
322
+ if type(p.prompt) == list:
323
+ assert(len(p.prompt) > 0)
324
+ else:
325
+ assert p.prompt is not None
326
+
327
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
328
+ processed = Processed(p, [], p.seed, "")
329
+ file.write(processed.infotext(p, 0))
330
+
331
+ devices.torch_gc()
332
+
333
+ seed = get_fixed_seed(p.seed)
334
+ subseed = get_fixed_seed(p.subseed)
335
+
336
+ if p.outpath_samples is not None:
337
+ os.makedirs(p.outpath_samples, exist_ok=True)
338
+
339
+ if p.outpath_grids is not None:
340
+ os.makedirs(p.outpath_grids, exist_ok=True)
341
+
342
+ modules.sd_hijack.model_hijack.apply_circular(p.tiling)
343
+ modules.sd_hijack.model_hijack.clear_comments()
344
+
345
+ comments = {}
346
+
347
+ shared.prompt_styles.apply_styles(p)
348
+
349
+ if type(p.prompt) == list:
350
+ all_prompts = p.prompt
351
+ else:
352
+ all_prompts = p.batch_size * p.n_iter * [p.prompt]
353
+
354
+ if type(seed) == list:
355
+ all_seeds = seed
356
+ else:
357
+ all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
358
+
359
+ if type(subseed) == list:
360
+ all_subseeds = subseed
361
+ else:
362
+ all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
363
+
364
+ def infotext(iteration=0, position_in_batch=0):
365
+ return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
366
+
367
+ if os.path.exists(cmd_opts.embeddings_dir):
368
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
369
+
370
+ infotexts = []
371
+ output_images = []
372
+
373
+ with torch.no_grad(), p.sd_model.ema_scope():
374
+ with devices.autocast():
375
+ p.init(all_prompts, all_seeds, all_subseeds)
376
+
377
+ if state.job_count == -1:
378
+ state.job_count = p.n_iter
379
+
380
+ for n in range(p.n_iter):
381
+ if state.skipped:
382
+ state.skipped = False
383
+
384
+ if state.interrupted:
385
+ break
386
+
387
+ prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
388
+ seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
389
+ subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
390
+
391
+ if (len(prompts) == 0):
392
+ break
393
+
394
+ #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
395
+ #c = p.sd_model.get_learned_conditioning(prompts)
396
+ with devices.autocast():
397
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
398
+ c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
399
+
400
+ if len(model_hijack.comments) > 0:
401
+ for comment in model_hijack.comments:
402
+ comments[comment] = 1
403
+
404
+ if p.n_iter > 1:
405
+ shared.state.job = f"Batch {n+1} out of {p.n_iter}"
406
+
407
+ with devices.autocast():
408
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
409
+
410
+ if state.interrupted or state.skipped:
411
+
412
+ # if we are interrupted, sample returns just noise
413
+ # use the image collected previously in the sampler loop
414
+ samples_ddim = shared.state.current_latent
415
+
416
+ samples_ddim = samples_ddim.to(devices.dtype_vae)
417
+ x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
418
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
419
+
420
+ del samples_ddim
421
+
422
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
423
+ lowvram.send_everything_to_cpu()
424
+
425
+ devices.torch_gc()
426
+
427
+ if opts.filter_nsfw:
428
+ import modules.safety as safety
429
+ x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
430
+
431
+ for i, x_sample in enumerate(x_samples_ddim):
432
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
433
+ x_sample = x_sample.astype(np.uint8)
434
+
435
+ if p.restore_faces:
436
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
437
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
438
+
439
+ devices.torch_gc()
440
+
441
+ x_sample = modules.face_restoration.restore_faces(x_sample)
442
+ devices.torch_gc()
443
+
444
+ image = Image.fromarray(x_sample)
445
+
446
+ if p.color_corrections is not None and i < len(p.color_corrections):
447
+ if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
448
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
449
+ image = apply_color_correction(p.color_corrections[i], image)
450
+
451
+ if p.overlay_images is not None and i < len(p.overlay_images):
452
+ overlay = p.overlay_images[i]
453
+
454
+ if p.paste_to is not None:
455
+ x, y, w, h = p.paste_to
456
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
457
+ image = images.resize_image(1, image, w, h)
458
+ base_image.paste(image, (x, y))
459
+ image = base_image
460
+
461
+ image = image.convert('RGBA')
462
+ image.alpha_composite(overlay)
463
+ image = image.convert('RGB')
464
+
465
+ if opts.samples_save and not p.do_not_save_samples:
466
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
467
+
468
+ text = infotext(n, i)
469
+ infotexts.append(text)
470
+ if opts.enable_pnginfo:
471
+ image.info["parameters"] = text
472
+ output_images.append(image)
473
+
474
+ del x_samples_ddim
475
+
476
+ devices.torch_gc()
477
+
478
+ state.nextjob()
479
+
480
+ p.color_corrections = None
481
+
482
+ index_of_first_image = 0
483
+ unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
484
+ if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
485
+ grid = images.image_grid(output_images, p.batch_size)
486
+
487
+ if opts.return_grid:
488
+ text = infotext()
489
+ infotexts.insert(0, text)
490
+ if opts.enable_pnginfo:
491
+ grid.info["parameters"] = text
492
+ output_images.insert(0, grid)
493
+ index_of_first_image = 1
494
+
495
+ if opts.grid_save:
496
+ images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
497
+
498
+ devices.torch_gc()
499
+ return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
500
+
501
+
502
+ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
503
+ sampler = None
504
+ firstphase_width = 0
505
+ firstphase_height = 0
506
+ firstphase_width_truncated = 0
507
+ firstphase_height_truncated = 0
508
+
509
+ def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
510
+ super().__init__(**kwargs)
511
+ self.enable_hr = enable_hr
512
+ self.scale_latent = scale_latent
513
+ self.denoising_strength = denoising_strength
514
+
515
+ def init(self, all_prompts, all_seeds, all_subseeds):
516
+ if self.enable_hr:
517
+ if state.job_count == -1:
518
+ state.job_count = self.n_iter * 2
519
+ else:
520
+ state.job_count = state.job_count * 2
521
+
522
+ desired_pixel_count = 512 * 512
523
+ actual_pixel_count = self.width * self.height
524
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
525
+
526
+ self.firstphase_width = math.ceil(scale * self.width / 64) * 64
527
+ self.firstphase_height = math.ceil(scale * self.height / 64) * 64
528
+ self.firstphase_width_truncated = int(scale * self.width)
529
+ self.firstphase_height_truncated = int(scale * self.height)
530
+
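To make the first-phase sizing concrete, here is a worked example that follows the formulas above; the latent downscale factor of 8 used for the crop at the end is an assumption (it corresponds to opt_f in sample() below).

import math

# target resolution 1024x768, first-phase budget of 512*512 pixels
scale = math.sqrt(512 * 512 / (1024 * 768))            # ~0.5774
firstphase_width = math.ceil(scale * 1024 / 64) * 64   # 640
firstphase_height = math.ceil(scale * 768 / 64) * 64   # 448
width_truncated = int(scale * 1024)                    # 591
height_truncated = int(scale * 768)                    # 443
# sample() below then crops (640 - 591) // 8 = 6 latent columns and
# (448 - 443) // 8 = 0 latent rows of padding before upscaling to 1024x768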
531
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
532
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
533
+
534
+ if not self.enable_hr:
535
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
536
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
537
+ return samples
538
+
539
+ x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
540
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
541
+
542
+ truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
543
+ truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
544
+
545
+ samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
546
+
547
+ if self.scale_latent:
548
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
549
+ else:
550
+ decoded_samples = decode_first_stage(self.sd_model, samples)
551
+
552
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
553
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
554
+ else:
555
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
556
+
557
+ batch_images = []
558
+ for i, x_sample in enumerate(lowres_samples):
559
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
560
+ x_sample = x_sample.astype(np.uint8)
561
+ image = Image.fromarray(x_sample)
562
+ image = images.resize_image(0, image, self.width, self.height)
563
+ image = np.array(image).astype(np.float32) / 255.0
564
+ image = np.moveaxis(image, 2, 0)
565
+ batch_images.append(image)
566
+
567
+ decoded_samples = torch.from_numpy(np.array(batch_images))
568
+ decoded_samples = decoded_samples.to(shared.device)
569
+ decoded_samples = 2. * decoded_samples - 1.
570
+
571
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
572
+
573
+ shared.state.nextjob()
574
+
575
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
576
+
577
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
578
+
579
+ # GC now before running the next img2img to prevent running out of memory
580
+ x = None
581
+ devices.torch_gc()
582
+
583
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
584
+
585
+ return samples
586
+
587
+
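As a rough illustration of how a caller such as txt2img drives the loop above, here is a minimal sketch. The constructor keyword arguments and the outdir option name are assumptions for the sketch (the full parameter list lives in StableDiffusionProcessing.__init__, which is not shown here), not an authoritative API.

import modules.shared as shared
from modules.processing import StableDiffusionProcessingTxt2Img, process_images

# minimal sketch; unspecified fields fall back to the defaults of StableDiffusionProcessing
p = StableDiffusionProcessingTxt2Img(
    sd_model=shared.sd_model,
    outpath_samples=shared.opts.outdir_txt2img_samples,  # assumed option name
    prompt="a fantasy landscape with a [mountain:lake:0.25]",
    negative_prompt="blurry",
    seed=-1,
    sampler_index=0,
    steps=20,
    width=512,
    height=512,
)
processed = process_images(p)   # returns a Processed object
print(processed.infotexts[0])   # assumed attribute mirroring the infotexts argument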
588
+ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
589
+ sampler = None
590
+
591
+ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
592
+ super().__init__(**kwargs)
593
+
594
+ self.init_images = init_images
595
+ self.resize_mode: int = resize_mode
596
+ self.denoising_strength: float = denoising_strength
597
+ self.init_latent = None
598
+ self.image_mask = mask
599
+ #self.image_unblurred_mask = None
600
+ self.latent_mask = None
601
+ self.mask_for_overlay = None
602
+ self.mask_blur = mask_blur
603
+ self.inpainting_fill = inpainting_fill
604
+ self.inpaint_full_res = inpaint_full_res
605
+ self.inpaint_full_res_padding = inpaint_full_res_padding
606
+ self.inpainting_mask_invert = inpainting_mask_invert
607
+ self.mask = None
608
+ self.nmask = None
609
+
610
+ def init(self, all_prompts, all_seeds, all_subseeds):
611
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
612
+ crop_region = None
613
+
614
+ if self.image_mask is not None:
615
+ self.image_mask = self.image_mask.convert('L')
616
+
617
+ if self.inpainting_mask_invert:
618
+ self.image_mask = ImageOps.invert(self.image_mask)
619
+
620
+ #self.image_unblurred_mask = self.image_mask
621
+
622
+ if self.mask_blur > 0:
623
+ self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
624
+
625
+ if self.inpaint_full_res:
626
+ self.mask_for_overlay = self.image_mask
627
+ mask = self.image_mask.convert('L')
628
+ crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
629
+ crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
630
+ x1, y1, x2, y2 = crop_region
631
+
632
+ mask = mask.crop(crop_region)
633
+ self.image_mask = images.resize_image(2, mask, self.width, self.height)
634
+ self.paste_to = (x1, y1, x2-x1, y2-y1)
635
+ else:
636
+ self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
637
+ np_mask = np.array(self.image_mask)
638
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
639
+ self.mask_for_overlay = Image.fromarray(np_mask)
640
+
641
+ self.overlay_images = []
642
+
643
+ latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
644
+
645
+ add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
646
+ if add_color_corrections:
647
+ self.color_corrections = []
648
+ imgs = []
649
+ for img in self.init_images:
650
+ image = img.convert("RGB")
651
+
652
+ if crop_region is None:
653
+ image = images.resize_image(self.resize_mode, image, self.width, self.height)
654
+
655
+ if self.image_mask is not None:
656
+ image_masked = Image.new('RGBa', (image.width, image.height))
657
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
658
+
659
+ self.overlay_images.append(image_masked.convert('RGBA'))
660
+
661
+ if crop_region is not None:
662
+ image = image.crop(crop_region)
663
+ image = images.resize_image(2, image, self.width, self.height)
664
+
665
+ if self.image_mask is not None:
666
+ if self.inpainting_fill != 1:
667
+ image = masking.fill(image, latent_mask)
668
+
669
+ if add_color_corrections:
670
+ self.color_corrections.append(setup_color_correction(image))
671
+
672
+ image = np.array(image).astype(np.float32) / 255.0
673
+ image = np.moveaxis(image, 2, 0)
674
+
675
+ imgs.append(image)
676
+
677
+ if len(imgs) == 1:
678
+ batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
679
+ if self.overlay_images is not None:
680
+ self.overlay_images = self.overlay_images * self.batch_size
681
+ elif len(imgs) <= self.batch_size:
682
+ self.batch_size = len(imgs)
683
+ batch_images = np.array(imgs)
684
+ else:
685
+ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
686
+
687
+ image = torch.from_numpy(batch_images)
688
+ image = 2. * image - 1.
689
+ image = image.to(shared.device)
690
+
691
+ self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
692
+
693
+ if self.image_mask is not None:
694
+ init_mask = latent_mask
695
+ latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
696
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
697
+ latmask = latmask[0]
698
+ latmask = np.around(latmask)
699
+ latmask = np.tile(latmask[None], (4, 1, 1))
700
+
701
+ self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
702
+ self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
703
+
704
+ # this needs to be fixed to be done in sample() using actual seeds for batches
705
+ if self.inpainting_fill == 2:
706
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
707
+ elif self.inpainting_fill == 3:
708
+ self.init_latent = self.init_latent * self.mask
709
+
710
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
711
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
712
+
713
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
714
+
715
+ if self.mask is not None:
716
+ samples = samples * self.nmask + self.init_latent * self.mask
717
+
718
+ del x
719
+ devices.torch_gc()
720
+
721
+ return samples
modules/prompt_parser.py ADDED
@@ -0,0 +1,366 @@
1
+ import re
2
+ from collections import namedtuple
3
+ from typing import List
4
+ import lark
5
+
6
+ # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
7
+ # will be represented with prompt_schedule like this (assuming steps=100):
8
+ # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
9
+ # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
10
+ # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
11
+ # [75, 'fantasy landscape with a lake and an oak in background masterful']
12
+ # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
13
+
14
+ schedule_parser = lark.Lark(r"""
15
+ !start: (prompt | /[][():]/+)*
16
+ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
17
+ !emphasized: "(" prompt ")"
18
+ | "(" prompt ":" prompt ")"
19
+ | "[" prompt "]"
20
+ scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
21
+ alternate: "[" prompt ("|" prompt)+ "]"
22
+ WHITESPACE: /\s+/
23
+ plain: /([^\\\[\]():|]|\\.)+/
24
+ %import common.SIGNED_NUMBER -> NUMBER
25
+ """)
26
+
27
+ def get_learned_conditioning_prompt_schedules(prompts, steps):
28
+ """
29
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
30
+ >>> g("test")
31
+ [[10, 'test']]
32
+ >>> g("a [b:3]")
33
+ [[3, 'a '], [10, 'a b']]
34
+ >>> g("a [b: 3]")
35
+ [[3, 'a '], [10, 'a b']]
36
+ >>> g("a [[[b]]:2]")
37
+ [[2, 'a '], [10, 'a [[b]]']]
38
+ >>> g("[(a:2):3]")
39
+ [[3, ''], [10, '(a:2)']]
40
+ >>> g("a [b : c : 1] d")
41
+ [[1, 'a b d'], [10, 'a c d']]
42
+ >>> g("a[b:[c:d:2]:1]e")
43
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
44
+ >>> g("a [unbalanced")
45
+ [[10, 'a [unbalanced']]
46
+ >>> g("a [b:.5] c")
47
+ [[5, 'a c'], [10, 'a b c']]
48
+ >>> g("a [{b|d{:.5] c") # not handling this right now
49
+ [[5, 'a c'], [10, 'a {b|d{ c']]
50
+ >>> g("((a][:b:c [d:3]")
51
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
52
+ """
53
+
54
+ def collect_steps(steps, tree):
55
+ l = [steps]
56
+ class CollectSteps(lark.Visitor):
57
+ def scheduled(self, tree):
58
+ tree.children[-1] = float(tree.children[-1])
59
+ if tree.children[-1] < 1:
60
+ tree.children[-1] *= steps
61
+ tree.children[-1] = min(steps, int(tree.children[-1]))
62
+ l.append(tree.children[-1])
63
+ def alternate(self, tree):
64
+ l.extend(range(1, steps+1))
65
+ CollectSteps().visit(tree)
66
+ return sorted(set(l))
67
+
68
+ def at_step(step, tree):
69
+ class AtStep(lark.Transformer):
70
+ def scheduled(self, args):
71
+ before, after, _, when = args
72
+ yield before or () if step <= when else after
73
+ def alternate(self, args):
74
+ yield next(args[(step - 1)%len(args)])
75
+ def start(self, args):
76
+ def flatten(x):
77
+ if type(x) == str:
78
+ yield x
79
+ else:
80
+ for gen in x:
81
+ yield from flatten(gen)
82
+ return ''.join(flatten(args))
83
+ def plain(self, args):
84
+ yield args[0].value
85
+ def __default__(self, data, children, meta):
86
+ for child in children:
87
+ yield from child
88
+ return AtStep().transform(tree)
89
+
90
+ def get_schedule(prompt):
91
+ try:
92
+ tree = schedule_parser.parse(prompt)
93
+ except lark.exceptions.LarkError as e:
94
+ if 0:
95
+ import traceback
96
+ traceback.print_exc()
97
+ return [[steps, prompt]]
98
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
99
+
100
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
101
+ return [promptdict[prompt] for prompt in prompts]
102
+
103
+
104
+ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
105
+
106
+
107
+ def get_learned_conditioning(model, prompts, steps):
108
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
109
+ and the sampling step at which this condition is to be replaced by the next one.
110
+
111
+ Input:
112
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
113
+
114
+ Output:
115
+ [
116
+ [
117
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
118
+ ],
119
+ [
120
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
121
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
122
+ ]
123
+ ]
124
+ """
125
+ res = []
126
+
127
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
128
+ cache = {}
129
+
130
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
131
+
132
+ cached = cache.get(prompt, None)
133
+ if cached is not None:
134
+ res.append(cached)
135
+ continue
136
+
137
+ texts = [x[1] for x in prompt_schedule]
138
+ conds = model.get_learned_conditioning(texts)
139
+
140
+ cond_schedule = []
141
+ for i, (end_at_step, text) in enumerate(prompt_schedule):
142
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
143
+
144
+ cache[prompt] = cond_schedule
145
+ res.append(cond_schedule)
146
+
147
+ return res
148
+
149
+
150
+ re_AND = re.compile(r"\bAND\b")
151
+ re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
152
+
153
+ def get_multicond_prompt_list(prompts):
154
+ res_indexes = []
155
+
156
+ prompt_flat_list = []
157
+ prompt_indexes = {}
158
+
159
+ for prompt in prompts:
160
+ subprompts = re_AND.split(prompt)
161
+
162
+ indexes = []
163
+ for subprompt in subprompts:
164
+ match = re_weight.search(subprompt)
165
+
166
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
167
+
168
+ weight = float(weight) if weight is not None else 1.0
169
+
170
+ index = prompt_indexes.get(text, None)
171
+ if index is None:
172
+ index = len(prompt_flat_list)
173
+ prompt_flat_list.append(text)
174
+ prompt_indexes[text] = index
175
+
176
+ indexes.append((index, weight))
177
+
178
+ res_indexes.append(indexes)
179
+
180
+ return res_indexes, prompt_flat_list, prompt_indexes
181
+
182
+
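For example, a single prompt with one AND separator splits like this (the exact whitespace preserved inside the subprompt texts is not important for the sketch):

res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(
    ["a forest AND a castle :0.3"]
)
# res_indexes      -> [[(0, 1.0), (1, 0.3)]]   (index into prompt_flat_list, weight)
# prompt_flat_list -> the two deduplicated subprompt texts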
183
+ class ComposableScheduledPromptConditioning:
184
+ def __init__(self, schedules, weight=1.0):
185
+ self.schedules: List[ScheduledPromptConditioning] = schedules
186
+ self.weight: float = weight
187
+
188
+
189
+ class MulticondLearnedConditioning:
190
+ def __init__(self, shape, batch):
191
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
192
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
193
+
194
+ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
195
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
196
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
197
+
198
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
199
+ """
200
+
201
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
202
+
203
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
204
+
205
+ res = []
206
+ for indexes in res_indexes:
207
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
208
+
209
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
210
+
211
+
212
+ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
213
+ param = c[0][0].cond
214
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
215
+ for i, cond_schedule in enumerate(c):
216
+ target_index = 0
217
+ for current, (end_at, cond) in enumerate(cond_schedule):
218
+ if current_step <= end_at:
219
+ target_index = current
220
+ break
221
+ res[i] = cond_schedule[target_index].cond
222
+
223
+ return res
224
+
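Tying this back to the schedule shown in the get_learned_conditioning docstring, the inner loop picks the first entry whose end_at_step has not yet been passed; blue_cond and green_cond below are just placeholder names for the two tensors.

# schedule for 'a [blue:green:5] jeweled crown' at 20 steps:
#   [ScheduledPromptConditioning(end_at_step=5,  cond=blue_cond),
#    ScheduledPromptConditioning(end_at_step=20, cond=green_cond)]
# reconstruct_cond_batch(c, current_step=3)  -> this prompt's row uses blue_cond
# reconstruct_cond_batch(c, current_step=12) -> this prompt's row uses green_cond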
225
+
226
+ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
227
+ param = c.batch[0][0].schedules[0].cond
228
+
229
+ tensors = []
230
+ conds_list = []
231
+
232
+ for batch_no, composable_prompts in enumerate(c.batch):
233
+ conds_for_batch = []
234
+
235
+ for cond_index, composable_prompt in enumerate(composable_prompts):
236
+ target_index = 0
237
+ for current, (end_at, cond) in enumerate(composable_prompt.schedules):
238
+ if current_step <= end_at:
239
+ target_index = current
240
+ break
241
+
242
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
243
+ tensors.append(composable_prompt.schedules[target_index].cond)
244
+
245
+ conds_list.append(conds_for_batch)
246
+
247
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
248
+ # and won't be able to torch.stack them. So this fixes that.
249
+ token_count = max([x.shape[0] for x in tensors])
250
+ for i in range(len(tensors)):
251
+ if tensors[i].shape[0] != token_count:
252
+ last_vector = tensors[i][-1:]
253
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
254
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
255
+
256
+ return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
257
+
258
+
259
+ re_attention = re.compile(r"""
260
+ \\\(|
261
+ \\\)|
262
+ \\\[|
263
+ \\]|
264
+ \\\\|
265
+ \\|
266
+ \(|
267
+ \[|
268
+ :([+-]?[.\d]+)\)|
269
+ \)|
270
+ ]|
271
+ [^\\()\[\]:]+|
272
+ :
273
+ """, re.X)
274
+
275
+
276
+ def parse_prompt_attention(text):
277
+ """
278
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
279
+ Accepted tokens are:
280
+ (abc) - increases attention to abc by a multiplier of 1.1
281
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
282
+ [abc] - decreases attention to abc by a multiplier of 1.1
283
+ \( - literal character '('
284
+ \[ - literal character '['
285
+ \) - literal character ')'
286
+ \] - literal character ']'
287
+ \\ - literal character '\'
288
+ anything else - just text
289
+
290
+ >>> parse_prompt_attention('normal text')
291
+ [['normal text', 1.0]]
292
+ >>> parse_prompt_attention('an (important) word')
293
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
294
+ >>> parse_prompt_attention('(unbalanced')
295
+ [['unbalanced', 1.1]]
296
+ >>> parse_prompt_attention('\(literal\]')
297
+ [['(literal]', 1.0]]
298
+ >>> parse_prompt_attention('(unnecessary)(parens)')
299
+ [['unnecessaryparens', 1.1]]
300
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
301
+ [['a ', 1.0],
302
+ ['house', 1.5730000000000004],
303
+ [' ', 1.1],
304
+ ['on', 1.0],
305
+ [' a ', 1.1],
306
+ ['hill', 0.55],
307
+ [', sun, ', 1.1],
308
+ ['sky', 1.4641000000000006],
309
+ ['.', 1.1]]
310
+ """
311
+
312
+ res = []
313
+ round_brackets = []
314
+ square_brackets = []
315
+
316
+ round_bracket_multiplier = 1.1
317
+ square_bracket_multiplier = 1 / 1.1
318
+
319
+ def multiply_range(start_position, multiplier):
320
+ for p in range(start_position, len(res)):
321
+ res[p][1] *= multiplier
322
+
323
+ for m in re_attention.finditer(text):
324
+ text = m.group(0)
325
+ weight = m.group(1)
326
+
327
+ if text.startswith('\\'):
328
+ res.append([text[1:], 1.0])
329
+ elif text == '(':
330
+ round_brackets.append(len(res))
331
+ elif text == '[':
332
+ square_brackets.append(len(res))
333
+ elif weight is not None and len(round_brackets) > 0:
334
+ multiply_range(round_brackets.pop(), float(weight))
335
+ elif text == ')' and len(round_brackets) > 0:
336
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
337
+ elif text == ']' and len(square_brackets) > 0:
338
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
339
+ else:
340
+ res.append([text, 1.0])
341
+
342
+ for pos in round_brackets:
343
+ multiply_range(pos, round_bracket_multiplier)
344
+
345
+ for pos in square_brackets:
346
+ multiply_range(pos, square_bracket_multiplier)
347
+
348
+ if len(res) == 0:
349
+ res = [["", 1.0]]
350
+
351
+ # merge runs of identical weights
352
+ i = 0
353
+ while i + 1 < len(res):
354
+ if res[i][1] == res[i + 1][1]:
355
+ res[i][0] += res[i + 1][0]
356
+ res.pop(i + 1)
357
+ else:
358
+ i += 1
359
+
360
+ return res
361
+
362
+ if __name__ == "__main__":
363
+ import doctest
364
+ doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
365
+ else:
366
+ import torch # doctest faster
modules/realesrgan_model.py ADDED
@@ -0,0 +1,133 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from basicsr.utils.download_util import load_file_from_url
8
+ from realesrgan import RealESRGANer
9
+
10
+ from modules.upscaler import Upscaler, UpscalerData
11
+ from modules.shared import cmd_opts, opts
12
+
13
+
14
+ class UpscalerRealESRGAN(Upscaler):
15
+ def __init__(self, path):
16
+ self.name = "RealESRGAN"
17
+ self.user_path = path
18
+ super().__init__()
19
+ try:
20
+ from basicsr.archs.rrdbnet_arch import RRDBNet
21
+ from realesrgan import RealESRGANer
22
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
23
+ self.enable = True
24
+ self.scalers = []
25
+ scalers = self.load_models(path)
26
+ for scaler in scalers:
27
+ if scaler.name in opts.realesrgan_enabled_models:
28
+ self.scalers.append(scaler)
29
+
30
+ except Exception:
31
+ print("Error importing Real-ESRGAN:", file=sys.stderr)
32
+ print(traceback.format_exc(), file=sys.stderr)
33
+ self.enable = False
34
+ self.scalers = []
35
+
36
+ def do_upscale(self, img, path):
37
+ if not self.enable:
38
+ return img
39
+
40
+ info = self.load_model(path)
41
+ if not os.path.exists(info.data_path):
42
+ print("Unable to load RealESRGAN model: %s" % info.name)
43
+ return img
44
+
45
+ upsampler = RealESRGANer(
46
+ scale=info.scale,
47
+ model_path=info.data_path,
48
+ model=info.model(),
49
+ half=not cmd_opts.no_half,
50
+ tile=opts.ESRGAN_tile,
51
+ tile_pad=opts.ESRGAN_tile_overlap,
52
+ )
53
+
54
+ upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
55
+
56
+ image = Image.fromarray(upsampled)
57
+ return image
58
+
59
+ def load_model(self, path):
60
+ try:
61
+ info = None
62
+ for scaler in self.scalers:
63
+ if scaler.data_path == path:
64
+ info = scaler
65
+
66
+ if info is None:
67
+ print(f"Unable to find model info: {path}")
68
+ return None
69
+
70
+ model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
71
+ info.data_path = model_file
72
+ return info
73
+ except Exception as e:
74
+ print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
75
+ print(traceback.format_exc(), file=sys.stderr)
76
+ return None
77
+
78
+ def load_models(self, _):
79
+ return get_realesrgan_models(self)
80
+
81
+
82
+ def get_realesrgan_models(scaler):
83
+ try:
84
+ from basicsr.archs.rrdbnet_arch import RRDBNet
85
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
86
+ models = [
87
+ UpscalerData(
88
+ name="R-ESRGAN General 4xV3",
89
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
90
+ scale=4,
91
+ upscaler=scaler,
92
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
93
+ ),
94
+ UpscalerData(
95
+ name="R-ESRGAN General WDN 4xV3",
96
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
97
+ scale=4,
98
+ upscaler=scaler,
99
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
100
+ ),
101
+ UpscalerData(
102
+ name="R-ESRGAN AnimeVideo",
103
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
104
+ scale=4,
105
+ upscaler=scaler,
106
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
107
+ ),
108
+ UpscalerData(
109
+ name="R-ESRGAN 4x+",
110
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
111
+ scale=4,
112
+ upscaler=scaler,
113
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
114
+ ),
115
+ UpscalerData(
116
+ name="R-ESRGAN 4x+ Anime6B",
117
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
118
+ scale=4,
119
+ upscaler=scaler,
120
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
121
+ ),
122
+ UpscalerData(
123
+ name="R-ESRGAN 2x+",
124
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
125
+ scale=2,
126
+ upscaler=scaler,
127
+ model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
128
+ ),
129
+ ]
130
+ return models
131
+ except Exception as e:
132
+ print("Error making Real-ESRGAN models list:", file=sys.stderr)
133
+ print(traceback.format_exc(), file=sys.stderr)
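A rough usage sketch for the upscaler above; the Upscaler base class (modules/upscaler.py) is not shown here, so the constructor argument and the file names are assumptions, and only models enabled in the settings end up in self.scalers.

from PIL import Image

upscaler = UpscalerRealESRGAN("models/RealESRGAN")   # assumed model directory
if upscaler.enable and upscaler.scalers:
    data = upscaler.scalers[0]                       # an UpscalerData entry from the list above
    result = upscaler.do_upscale(Image.open("input.png"), data.data_path)
    result.save("output.png")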
modules/safe.py ADDED
@@ -0,0 +1,110 @@
1
+ # this code is adapted from the script contributed by anon from /h/
2
+
3
+ import io
4
+ import pickle
5
+ import collections
6
+ import sys
7
+ import traceback
8
+
9
+ import torch
10
+ import numpy
11
+ import _codecs
12
+ import zipfile
13
+ import re
14
+
15
+
16
+ # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
17
+ TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
18
+
19
+
20
+ def encode(*args):
21
+ out = _codecs.encode(*args)
22
+ return out
23
+
24
+
25
+ class RestrictedUnpickler(pickle.Unpickler):
26
+ def persistent_load(self, saved_id):
27
+ assert saved_id[0] == 'storage'
28
+ return TypedStorage()
29
+
30
+ def find_class(self, module, name):
31
+ if module == 'collections' and name == 'OrderedDict':
32
+ return getattr(collections, name)
33
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
34
+ return getattr(torch._utils, name)
35
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
36
+ return getattr(torch, name)
37
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
38
+ return getattr(torch.nn.modules.container, name)
39
+ if module == 'numpy.core.multiarray' and name == 'scalar':
40
+ return numpy.core.multiarray.scalar
41
+ if module == 'numpy' and name == 'dtype':
42
+ return numpy.dtype
43
+ if module == '_codecs' and name == 'encode':
44
+ return encode
45
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
46
+ import pytorch_lightning.callbacks
47
+ return pytorch_lightning.callbacks.model_checkpoint
48
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
49
+ import pytorch_lightning.callbacks.model_checkpoint
50
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
51
+ if module == "__builtin__" and name == 'set':
52
+ return set
53
+
54
+ # Forbid everything else.
55
+ raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
56
+
57
+
58
+ allowed_zip_names = ["archive/data.pkl", "archive/version"]
59
+ allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
60
+
61
+
62
+ def check_zip_filenames(filename, names):
63
+ for name in names:
64
+ if name in allowed_zip_names:
65
+ continue
66
+ if allowed_zip_names_re.match(name):
67
+ continue
68
+
69
+ raise Exception(f"bad file inside {filename}: {name}")
70
+
71
+
72
+ def check_pt(filename):
73
+ try:
74
+
75
+ # new pytorch format is a zip file
76
+ with zipfile.ZipFile(filename) as z:
77
+ check_zip_filenames(filename, z.namelist())
78
+
79
+ with z.open('archive/data.pkl') as file:
80
+ unpickler = RestrictedUnpickler(file)
81
+ unpickler.load()
82
+
83
+ except zipfile.BadZipfile:
84
+
85
+ # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
86
+ with open(filename, "rb") as file:
87
+ unpickler = RestrictedUnpickler(file)
88
+ for i in range(5):
89
+ unpickler.load()
90
+
91
+
92
+ def load(filename, *args, **kwargs):
93
+ from modules import shared
94
+
95
+ try:
96
+ if not shared.cmd_opts.disable_safe_unpickle:
97
+ check_pt(filename)
98
+
99
+ except Exception:
100
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
101
+ print(traceback.format_exc(), file=sys.stderr)
102
+ print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
103
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
104
+ return None
105
+
106
+ return unsafe_torch_load(filename, *args, **kwargs)
107
+
108
+
109
+ unsafe_torch_load = torch.load
110
+ torch.load = load
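Because importing this module replaces torch.load globally, any later checkpoint load goes through the restricted unpickler. A minimal sketch (the checkpoint path is made up):

import modules.safe  # noqa: F401  (side effect: torch.load is now the checked loader)
import torch

state_dict = torch.load("models/Stable-diffusion/model.ckpt", map_location="cpu")
if state_dict is None:
    print("checkpoint failed the safety check and was not read")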
modules/safety.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
3
+ from transformers import AutoFeatureExtractor
4
+ from PIL import Image
5
+
6
+ import modules.shared as shared
7
+
8
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
9
+ safety_feature_extractor = None
10
+ safety_checker = None
11
+
12
+ def numpy_to_pil(images):
13
+ """
14
+ Convert a numpy image or a batch of images to a PIL image.
15
+ """
16
+ if images.ndim == 3:
17
+ images = images[None, ...]
18
+ images = (images * 255).round().astype("uint8")
19
+ pil_images = [Image.fromarray(image) for image in images]
20
+
21
+ return pil_images
22
+
23
+ # check and replace nsfw content
24
+ def check_safety(x_image):
25
+ global safety_feature_extractor, safety_checker
26
+
27
+ if safety_feature_extractor is None:
28
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
29
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
30
+
31
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
32
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
33
+
34
+ return x_checked_image, has_nsfw_concept
35
+
36
+
37
+ def censor_batch(x):
38
+ x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
39
+ x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
40
+ x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
41
+
42
+ return x
modules/scripts.py ADDED
@@ -0,0 +1,201 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ import modules.ui as ui
6
+ import gradio as gr
7
+
8
+ from modules.processing import StableDiffusionProcessing
9
+ from modules import shared
10
+
11
+ class Script:
12
+ filename = None
13
+ args_from = None
14
+ args_to = None
15
+
16
+ # The title of the script. This is what will be displayed in the dropdown menu.
17
+ def title(self):
18
+ raise NotImplementedError()
19
+
20
+ # How the script is displayed in the UI. See https://gradio.app/docs/#components
21
+ # for the different UI components you can use and how to create them.
22
+ # Most UI components can return a value, such as a boolean for a checkbox.
23
+ # The returned values are passed to the run method as parameters.
24
+ def ui(self, is_img2img):
25
+ pass
26
+
27
+ # Determines when the script should be shown in the dropdown menu via the
28
+ # returned value. As an example:
29
+ # is_img2img is True if the current tab is img2img, and False if it is txt2img.
30
+ # Thus, return is_img2img to only show the script on the img2img tab.
31
+ def show(self, is_img2img):
32
+ return True
33
+
34
+ # This is where the additional processing is implemented. The parameters include
35
+ # self, the model object "p" (a StableDiffusionProcessing class, see
36
+ # processing.py), and the parameters returned by the ui method.
37
+ # Custom functions can be defined here, and additional libraries can be imported
38
+ # to be used in processing. The return value should be a Processed object, which is
39
+ # what is returned by the process_images method.
40
+ def run(self, *args):
41
+ raise NotImplementedError()
42
+
43
+ # The description method is currently unused.
44
+ # To add a description that appears when hovering over the title, amend the "titles"
45
+ # dict in script.js to include the script title (returned by title) as a key, and
46
+ # your description as the value.
47
+ def describe(self):
48
+ return ""
49
+
50
+
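Putting the comments above together, a custom script is a file dropped into the folder that load_scripts() scans, containing a subclass of Script; the Gradio textbox and the prompt tweak below are purely illustrative.

import gradio as gr

import modules.scripts as scripts
from modules.processing import process_images


class ExampleScript(scripts.Script):
    def title(self):
        return "Append text to prompt"

    def show(self, is_img2img):
        return True  # appear on both txt2img and img2img tabs

    def ui(self, is_img2img):
        suffix = gr.Textbox(label="Text to append")
        return [suffix]  # returned controls feed their values into run()

    def run(self, p, suffix):
        if suffix:
            p.prompt = f"{p.prompt}, {suffix}"
        return process_images(p)  # a Processed object, as the UI expects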
51
+ scripts_data = []
52
+
53
+
54
+ def load_scripts(basedir):
55
+ if not os.path.exists(basedir):
56
+ return
57
+
58
+ for filename in sorted(os.listdir(basedir)):
59
+ path = os.path.join(basedir, filename)
60
+
61
+ if not os.path.isfile(path):
62
+ continue
63
+
64
+ try:
65
+ with open(path, "r", encoding="utf8") as file:
66
+ text = file.read()
67
+
68
+ from types import ModuleType
69
+ compiled = compile(text, path, 'exec')
70
+ module = ModuleType(filename)
71
+ exec(compiled, module.__dict__)
72
+
73
+ for key, script_class in module.__dict__.items():
74
+ if type(script_class) == type and issubclass(script_class, Script):
75
+ scripts_data.append((script_class, path))
76
+
77
+ except Exception:
78
+ print(f"Error loading script: {filename}", file=sys.stderr)
79
+ print(traceback.format_exc(), file=sys.stderr)
80
+
81
+
82
+ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
83
+ try:
84
+ res = func(*args, **kwargs)
85
+ return res
86
+ except Exception:
87
+ print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
88
+ print(traceback.format_exc(), file=sys.stderr)
89
+
90
+ return default
91
+
92
+
93
+ class ScriptRunner:
94
+ def __init__(self):
95
+ self.scripts = []
96
+
97
+ def setup_ui(self, is_img2img):
98
+ for script_class, path in scripts_data:
99
+ script = script_class()
100
+ script.filename = path
101
+
102
+ if not script.show(is_img2img):
103
+ continue
104
+
105
+ self.scripts.append(script)
106
+
107
+ titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
108
+
109
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
110
+ inputs = [dropdown]
111
+
112
+ for script in self.scripts:
113
+ script.args_from = len(inputs)
114
+ script.args_to = len(inputs)
115
+
116
+ controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
117
+
118
+ if controls is None:
119
+ continue
120
+
121
+ for control in controls:
122
+ control.custom_script_source = os.path.basename(script.filename)
123
+ control.visible = False
124
+
125
+ inputs += controls
126
+ script.args_to = len(inputs)
127
+
128
+ def select_script(script_index):
129
+ if 0 < script_index <= len(self.scripts):
130
+ script = self.scripts[script_index-1]
131
+ args_from = script.args_from
132
+ args_to = script.args_to
133
+ else:
134
+ args_from = 0
135
+ args_to = 0
136
+
137
+ return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
138
+
139
+ dropdown.change(
140
+ fn=select_script,
141
+ inputs=[dropdown],
142
+ outputs=inputs
143
+ )
144
+
145
+ return inputs
146
+
147
+ def run(self, p: StableDiffusionProcessing, *args):
148
+ script_index = args[0]
149
+
150
+ if script_index == 0:
151
+ return None
152
+
153
+ script = self.scripts[script_index-1]
154
+
155
+ if script is None:
156
+ return None
157
+
158
+ script_args = args[script.args_from:script.args_to]
159
+ processed = script.run(p, *script_args)
160
+
161
+ shared.total_tqdm.clear()
162
+
163
+ return processed
164
+
165
+ def reload_sources(self):
166
+ for si, script in list(enumerate(self.scripts)):
167
+ with open(script.filename, "r", encoding="utf8") as file:
168
+ args_from = script.args_from
169
+ args_to = script.args_to
170
+ filename = script.filename
171
+ text = file.read()
172
+
173
+ from types import ModuleType
174
+
175
+ compiled = compile(text, filename, 'exec')
176
+ module = ModuleType(script.filename)
177
+ exec(compiled, module.__dict__)
178
+
179
+ for key, script_class in module.__dict__.items():
180
+ if type(script_class) == type and issubclass(script_class, Script):
181
+ self.scripts[si] = script_class()
182
+ self.scripts[si].filename = filename
183
+ self.scripts[si].args_from = args_from
184
+ self.scripts[si].args_to = args_to
185
+
186
+ scripts_txt2img = ScriptRunner()
187
+ scripts_img2img = ScriptRunner()
188
+
189
+ def reload_script_body_only():
190
+ scripts_txt2img.reload_sources()
191
+ scripts_img2img.reload_sources()
192
+
193
+
194
+ def reload_scripts(basedir):
195
+ global scripts_txt2img, scripts_img2img
196
+
197
+ scripts_data.clear()
198
+ load_scripts(basedir)
199
+
200
+ scripts_txt2img = ScriptRunner()
201
+ scripts_img2img = ScriptRunner()