aclicheroux committed · Commit e0c66e4 · Parent(s): d9f9915

Initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- CODEOWNERS +1 -0
- app.py +137 -0
- artists.csv +0 -0
- environment-wsl2.yaml +11 -0
- javascript/aspectRatioOverlay.js +119 -0
- javascript/contextMenus.js +177 -0
- javascript/dragdrop.js +86 -0
- javascript/edit-attention.js +45 -0
- javascript/hints.js +121 -0
- javascript/imageMaskFix.js +45 -0
- javascript/imageviewer.js +236 -0
- javascript/notification.js +49 -0
- javascript/progressbar.js +76 -0
- javascript/textualInversion.js +8 -0
- javascript/ui.js +234 -0
- launch.py +169 -0
- modules/artists.py +25 -0
- modules/bsrgan_model.py +76 -0
- modules/bsrgan_model_arch.py +102 -0
- modules/codeformer/codeformer_arch.py +278 -0
- modules/codeformer/vqgan_arch.py +437 -0
- modules/codeformer_model.py +140 -0
- modules/deepbooru.py +173 -0
- modules/devices.py +72 -0
- modules/errors.py +10 -0
- modules/esrgan_model.py +158 -0
- modules/esrgan_model_arch.py +80 -0
- modules/extras.py +222 -0
- modules/face_restoration.py +19 -0
- modules/generation_parameters_copypaste.py +101 -0
- modules/gfpgan_model.py +115 -0
- modules/hypernetworks/hypernetwork.py +314 -0
- modules/hypernetworks/ui.py +47 -0
- modules/images.py +465 -0
- modules/img2img.py +137 -0
- modules/interrogate.py +171 -0
- modules/ldsr_model.py +54 -0
- modules/ldsr_model_arch.py +222 -0
- modules/lowvram.py +82 -0
- modules/masking.py +99 -0
- modules/memmon.py +85 -0
- modules/modelloader.py +153 -0
- modules/ngrok.py +15 -0
- modules/paths.py +40 -0
- modules/processing.py +721 -0
- modules/prompt_parser.py +366 -0
- modules/realesrgan_model.py +133 -0
- modules/safe.py +110 -0
- modules/safety.py +42 -0
- modules/scripts.py +201 -0
CODEOWNERS
ADDED
@@ -0,0 +1 @@
* @AUTOMATIC1111
app.py
ADDED
@@ -0,0 +1,137 @@
import os
import threading
import time
import importlib
import signal
import threading

from fastapi.middleware.gzip import GZipMiddleware

from modules.paths import script_path

from modules import devices, sd_samplers
import modules.codeformer_model as codeformer
import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img

import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.shared as shared
import modules.txt2img

import modules.ui
from modules import devices
from modules import modelloader
from modules.paths import script_path
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork


queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return f


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):
        devices.torch_gc()

        shared.state.sampling_step = 0
        shared.state.job_count = -1
        shared.state.job_no = 0
        shared.state.job_timestamp = shared.state.get_job_timestamp()
        shared.state.current_latent = None
        shared.state.current_image = None
        shared.state.current_image_sampling_step = 0
        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.textinfo = None

        with queue_lock:
            res = func(*args, **kwargs)

        shared.state.job = ""
        shared.state.job_count = 0

        devices.torch_gc()

        return res

    return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)

def initialize():
    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    codeformer.setup_model(cmd_opts.codeformer_models_path)
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    shared.face_restorers.append(modules.face_restoration.FaceRestoration())
    modelloader.load_upscalers()

    modules.scripts.load_scripts(os.path.join(script_path, "scripts"))

    shared.sd_model = modules.sd_models.load_model()
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
    shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)


def webui():
    initialize()

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)

    while 1:

        demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)

        app, local_url, share_url = demo.launch(
            share=cmd_opts.share,
            server_name="0.0.0.0" if cmd_opts.listen else None,
            server_port=cmd_opts.port,
            debug=cmd_opts.gradio_debug,
            auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
            inbrowser=cmd_opts.autolaunch,
            prevent_thread_lock=True
        )

        app.add_middleware(GZipMiddleware, minimum_size=1000)

        while 1:
            time.sleep(0.5)
            if getattr(demo, 'do_restart', False):
                time.sleep(0.5)
                demo.close()
                time.sleep(0.5)
                break

        sd_samplers.set_samplers()

        print('Reloading Custom Scripts')
        modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
        print('Reloading modules: modules.ui')
        importlib.reload(modules.ui)
        print('Refreshing Model List')
        modules.sd_models.list_models()
        print('Restarting Gradio')


if __name__ == "__main__":
    webui()
artists.csv
ADDED
The diff for this file is too large to render.
environment-wsl2.yaml
ADDED
@@ -0,0 +1,11 @@
name: automatic
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.10
  - pip=22.2.2
  - cudatoolkit=11.3
  - pytorch=1.12.1
  - torchvision=0.13.1
  - numpy=1.23.1
javascript/aspectRatioOverlay.js
ADDED
@@ -0,0 +1,119 @@
let currentWidth = null;
let currentHeight = null;
let arFrameTimeout = setTimeout(function(){},0);

function dimensionChange(e,dimname){

    if(dimname == 'Width'){
        currentWidth = e.target.value*1.0
    }
    if(dimname == 'Height'){
        currentHeight = e.target.value*1.0
    }

    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))

    if(!inImg2img){
        return;
    }

    var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
    if(img2imgMode){
        img2imgMode=img2imgMode.innerText
    }else{
        return;
    }

    var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
    var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')

    var targetElement = null;

    if(img2imgMode=='img2img' && redrawImage){
        targetElement = redrawImage;
    }else if(img2imgMode=='Inpaint' && inpaintImage){
        targetElement = inpaintImage;
    }

    if(targetElement){

        var arPreviewRect = gradioApp().querySelector('#imageARPreview');
        if(!arPreviewRect){
            arPreviewRect = document.createElement('div')
            arPreviewRect.id = "imageARPreview";
            gradioApp().getRootNode().appendChild(arPreviewRect)
        }

        var viewportOffset = targetElement.getBoundingClientRect();

        viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )

        scaledx = targetElement.naturalWidth*viewportscale
        scaledy = targetElement.naturalHeight*viewportscale

        cleintRectTop = (viewportOffset.top+window.scrollY)
        cleintRectLeft = (viewportOffset.left+window.scrollX)
        cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
        cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)

        viewRectTop = cleintRectCentreY-(scaledy/2)
        viewRectLeft = cleintRectCentreX-(scaledx/2)
        arRectWidth = scaledx
        arRectHeight = scaledy

        arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
        arscaledx = currentWidth*arscale
        arscaledy = currentHeight*arscale

        arRectTop = cleintRectCentreY-(arscaledy/2)
        arRectLeft = cleintRectCentreX-(arscaledx/2)
        arRectWidth = arscaledx
        arRectHeight = arscaledy

        arPreviewRect.style.top = arRectTop+'px';
        arPreviewRect.style.left = arRectLeft+'px';
        arPreviewRect.style.width = arRectWidth+'px';
        arPreviewRect.style.height = arRectHeight+'px';

        clearTimeout(arFrameTimeout);
        arFrameTimeout = setTimeout(function(){
            arPreviewRect.style.display = 'none';
        },2000);

        arPreviewRect.style.display = 'block';

    }

}


onUiUpdate(function(){
    var arPreviewRect = gradioApp().querySelector('#imageARPreview');
    if(arPreviewRect){
        arPreviewRect.style.display = 'none';
    }
    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
    if(inImg2img){
        let inputs = gradioApp().querySelectorAll('input');
        inputs.forEach(function(e){
            let parentLabel = e.parentElement.querySelector('label')
            if(parentLabel && parentLabel.innerText){
                if(!e.classList.contains('scrollwatch')){
                    if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
                        e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
                        e.classList.add('scrollwatch')
                    }
                    if(parentLabel.innerText == 'Width'){
                        currentWidth = e.value*1.0
                    }
                    if(parentLabel.innerText == 'Height'){
                        currentHeight = e.value*1.0
                    }
                }
            }
        })
    }
});
javascript/contextMenus.js
ADDED
@@ -0,0 +1,177 @@
contextMenuInit = function(){
    let eventListenerApplied=false;
    let menuSpecs = new Map();

    const uid = function(){
        return Date.now().toString(36) + Math.random().toString(36).substr(2);
    }

    function showContextMenu(event,element,menuEntries){
        let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
        let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;

        let oldMenu = gradioApp().querySelector('#context-menu')
        if(oldMenu){
            oldMenu.remove()
        }

        let tabButton = uiCurrentTab
        let baseStyle = window.getComputedStyle(tabButton)

        const contextMenu = document.createElement('nav')
        contextMenu.id = "context-menu"
        contextMenu.style.background = baseStyle.background
        contextMenu.style.color = baseStyle.color
        contextMenu.style.fontFamily = baseStyle.fontFamily
        contextMenu.style.top = posy+'px'
        contextMenu.style.left = posx+'px'

        const contextMenuList = document.createElement('ul')
        contextMenuList.className = 'context-menu-items';
        contextMenu.append(contextMenuList);

        menuEntries.forEach(function(entry){
            let contextMenuEntry = document.createElement('a')
            contextMenuEntry.innerHTML = entry['name']
            contextMenuEntry.addEventListener("click", function(e) {
                entry['func']();
            })
            contextMenuList.append(contextMenuEntry);
        })

        gradioApp().getRootNode().appendChild(contextMenu)

        let menuWidth = contextMenu.offsetWidth + 4;
        let menuHeight = contextMenu.offsetHeight + 4;

        let windowWidth = window.innerWidth;
        let windowHeight = window.innerHeight;

        if ( (windowWidth - posx) < menuWidth ) {
            contextMenu.style.left = windowWidth - menuWidth + "px";
        }

        if ( (windowHeight - posy) < menuHeight ) {
            contextMenu.style.top = windowHeight - menuHeight + "px";
        }

    }

    function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){

        currentItems = menuSpecs.get(targetEmementSelector)

        if(!currentItems){
            currentItems = []
            menuSpecs.set(targetEmementSelector,currentItems);
        }
        let newItem = {'id':targetEmementSelector+'_'+uid(),
                       'name':entryName,
                       'func':entryFunction,
                       'isNew':true}

        currentItems.push(newItem)
        return newItem['id']
    }

    function removeContextMenuOption(uid){
        menuSpecs.forEach(function(v,k) {
            let index = -1
            v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
            if(index>=0){
                v.splice(index, 1);
            }
        })
    }

    function addContextMenuEventListener(){
        if(eventListenerApplied){
            return;
        }
        gradioApp().addEventListener("click", function(e) {
            let source = e.composedPath()[0]
            if(source.id && source.id.indexOf('check_progress')>-1){
                return
            }

            let oldMenu = gradioApp().querySelector('#context-menu')
            if(oldMenu){
                oldMenu.remove()
            }
        });
        gradioApp().addEventListener("contextmenu", function(e) {
            let oldMenu = gradioApp().querySelector('#context-menu')
            if(oldMenu){
                oldMenu.remove()
            }
            menuSpecs.forEach(function(v,k) {
                if(e.composedPath()[0].matches(k)){
                    showContextMenu(e,e.composedPath()[0],v)
                    e.preventDefault()
                    return
                }
            })
        });
        eventListenerApplied=true

    }

    return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
}

initResponse = contextMenuInit();
appendContextMenuOption = initResponse[0];
removeContextMenuOption = initResponse[1];
addContextMenuEventListener = initResponse[2];

(function(){
    //Start example Context Menu Items
    let generateOnRepeat = function(genbuttonid,interruptbuttonid){
        let genbutton = gradioApp().querySelector(genbuttonid);
        let interruptbutton = gradioApp().querySelector(interruptbuttonid);
        if(!interruptbutton.offsetParent){
            genbutton.click();
        }
        clearInterval(window.generateOnRepeatInterval)
        window.generateOnRepeatInterval = setInterval(function(){
            if(!interruptbutton.offsetParent){
                genbutton.click();
            }
        },
        500)
    }

    appendContextMenuOption('#txt2img_generate','Generate forever',function(){
        generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
    })
    appendContextMenuOption('#img2img_generate','Generate forever',function(){
        generateOnRepeat('#img2img_generate','#img2img_interrupt');
    })

    let cancelGenerateForever = function(){
        clearInterval(window.generateOnRepeatInterval)
    }

    appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)

    appendContextMenuOption('#roll','Roll three',
        function(){
            let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
            setTimeout(function(){rollbutton.click()},100)
            setTimeout(function(){rollbutton.click()},200)
            setTimeout(function(){rollbutton.click()},300)
        }
    )
})();
//End example Context Menu Items

onUiUpdate(function(){
    addContextMenuEventListener()
});
javascript/dragdrop.js
ADDED
@@ -0,0 +1,86 @@
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard

function isValidImageList( files ) {
    return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
}

function dropReplaceImage( imgWrap, files ) {
    if ( ! isValidImageList( files ) ) {
        return;
    }

    imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
    const callback = () => {
        const fileInput = imgWrap.querySelector('input[type="file"]');
        if ( fileInput ) {
            fileInput.files = files;
            fileInput.dispatchEvent(new Event('change'));
        }
    };

    if ( imgWrap.closest('#pnginfo_image') ) {
        // special treatment for PNG Info tab, wait for fetch request to finish
        const oldFetch = window.fetch;
        window.fetch = async (input, options) => {
            const response = await oldFetch(input, options);
            if ( 'api/predict/' === input ) {
                const content = await response.text();
                window.fetch = oldFetch;
                window.requestAnimationFrame( () => callback() );
                return new Response(content, {
                    status: response.status,
                    statusText: response.statusText,
                    headers: response.headers
                })
            }
            return response;
        };
    } else {
        window.requestAnimationFrame( () => callback() );
    }
}

window.document.addEventListener('dragover', e => {
    const target = e.composedPath()[0];
    const imgWrap = target.closest('[data-testid="image"]');
    if ( !imgWrap ) {
        return;
    }
    e.stopPropagation();
    e.preventDefault();
    e.dataTransfer.dropEffect = 'copy';
});

window.document.addEventListener('drop', e => {
    const target = e.composedPath()[0];
    const imgWrap = target.closest('[data-testid="image"]');
    if ( !imgWrap ) {
        return;
    }
    e.stopPropagation();
    e.preventDefault();
    const files = e.dataTransfer.files;
    dropReplaceImage( imgWrap, files );
});

window.addEventListener('paste', e => {
    const files = e.clipboardData.files;
    if ( ! isValidImageList( files ) ) {
        return;
    }

    const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
        .filter(el => uiElementIsVisible(el));
    if ( ! visibleImageFields.length ) {
        return;
    }

    const firstFreeImageField = visibleImageFields
        .filter(el => el.querySelector('input[type=file]'))?.[0];

    dropReplaceImage(
        firstFreeImageField ?
            firstFreeImageField :
            visibleImageFields[visibleImageFields.length - 1]
        , files );
});
javascript/edit-attention.js
ADDED
@@ -0,0 +1,45 @@
addEventListener('keydown', (event) => {
    let target = event.originalTarget || event.composedPath()[0];
    if (!target.hasAttribute("placeholder")) return;
    if (!target.placeholder.toLowerCase().includes("prompt")) return;

    let plus = "ArrowUp"
    let minus = "ArrowDown"
    if (event.key != plus && event.key != minus) return;

    selectionStart = target.selectionStart;
    selectionEnd = target.selectionEnd;
    if(selectionStart == selectionEnd) return;

    event.preventDefault();

    if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
        target.value = target.value.slice(0, selectionStart) +
            "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
            target.value.slice(selectionEnd);

        target.focus();
        target.selectionStart = selectionStart + 1;
        target.selectionEnd = selectionEnd + 1;

    } else {
        end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
        weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
        if (isNaN(weight)) return;
        if (event.key == minus) weight -= 0.1;
        if (event.key == plus) weight += 0.1;

        weight = parseFloat(weight.toPrecision(12));

        target.value = target.value.slice(0, selectionEnd + 1) +
            weight +
            target.value.slice(selectionEnd + 1 + end - 1);

        target.focus();
        target.selectionStart = selectionStart;
        target.selectionEnd = selectionEnd;
    }
    // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
    // internal Svelte data binding remains in sync.
    target.dispatchEvent(new Event("input", { bubbles: true }));
});
javascript/hints.js
ADDED
@@ -0,0 +1,121 @@
// mouseover tooltips for various UI elements

titles = {
    "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
    "Sampling method": "Which algorithm to use to produce the image",
    "GFPGAN": "Restore low quality faces using GFPGAN neural network",
    "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
    "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",

    "Batch count": "How many batches of images to create",
    "Batch size": "How many images to create in a single batch",
    "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
    "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
    "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
    "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
    "\u{1f3a8}": "Add a random artist to the prompt.",
    "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
    "\u{1f4c2}": "Open images output directory",

    "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
    "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

    "Just resize": "Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.",
    "Crop and resize": "Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.",
    "Resize and fill": "Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors.",

    "Mask blur": "How much to blur the mask before processing, in pixels.",
    "Masked content": "What to put inside the masked area before processing it with Stable Diffusion.",
    "fill": "fill it with colors of the image",
    "original": "keep whatever was there originally",
    "latent noise": "fill it with latent space noise",
    "latent nothing": "fill it with latent space zeroes",
    "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",

    "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
    "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",

    "Skip": "Stop processing current image and continue processing.",
    "Interrupt": "Stop processing images and return any results accumulated so far.",
    "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",

    "X values": "Separate values for X axis using commas.",
    "Y values": "Separate values for Y axis using commas.",

    "None": "Do not do anything special",
    "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
    "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
    "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",

    "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
    "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",

    "Tiling": "Produce an image that can be tiled.",
    "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",

    "Variation seed": "Seed of a different picture to be mixed into the generation.",
    "Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
    "Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
    "Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",

    "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",

    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
    "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",

    "Loopback": "Process an image, use it as an input, repeat.",
    "Loops": "How many times to repeat processing an image and using it as input for the next iteration",

    "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
    "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
    "Apply style": "Insert selected styles into prompt fields",
    "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style will use that as a placeholder for your prompt when you use the style in the future.",

    "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",

    "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",

    "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
    "Scale latent": "Upscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",

    "Eta noise seed delta": "If this value is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
    "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",

    "Filename word regex": "This regular expression will be used to extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
    "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",

    "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply."
}


onUiUpdate(function(){
    gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
        tooltip = titles[span.textContent];

        if(!tooltip){
            tooltip = titles[span.value];
        }

        if(!tooltip){
            for (const c of span.classList) {
                if (c in titles) {
                    tooltip = titles[c];
                    break;
                }
            }
        }

        if(tooltip){
            span.title = tooltip;
        }
    })

    gradioApp().querySelectorAll('select').forEach(function(select){
        if (select.onchange != null) return;

        select.onchange = function(){
            select.title = titles[select.value] || "";
        }
    })
})
javascript/imageMaskFix.js
ADDED
@@ -0,0 +1,45 @@
/**
 * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
 * @see https://github.com/gradio-app/gradio/issues/1721
 */
window.addEventListener( 'resize', () => imageMaskResize());
function imageMaskResize() {
    const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
    if ( ! canvases.length ) {
        canvases_fixed = false;
        window.removeEventListener( 'resize', imageMaskResize );
        return;
    }

    const wrapper = canvases[0].closest('.touch-none');
    const previewImage = wrapper.previousElementSibling;

    if ( ! previewImage.complete ) {
        previewImage.addEventListener( 'load', () => imageMaskResize());
        return;
    }

    const w = previewImage.width;
    const h = previewImage.height;
    const nw = previewImage.naturalWidth;
    const nh = previewImage.naturalHeight;
    const portrait = nh > nw;
    const factor = portrait;

    const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
    const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);

    wrapper.style.width = `${wW}px`;
    wrapper.style.height = `${wH}px`;
    wrapper.style.left = `${(w-wW)/2}px`;
    wrapper.style.top = `${(h-wH)/2}px`;

    canvases.forEach( c => {
        c.style.width = c.style.height = '';
        c.style.maxWidth = '100%';
        c.style.maxHeight = '100%';
        c.style.objectFit = 'contain';
    });
}

onUiUpdate(() => imageMaskResize());
javascript/imageviewer.js
ADDED
@@ -0,0 +1,236 @@
// A full size 'lightbox' preview modal shown when left clicking on gallery previews
function closeModal() {
    gradioApp().getElementById("lightboxModal").style.display = "none";
}

function showModal(event) {
    const source = event.target || event.srcElement;
    const modalImage = gradioApp().getElementById("modalImage")
    const lb = gradioApp().getElementById("lightboxModal")
    modalImage.src = source.src
    if (modalImage.style.display === 'none') {
        lb.style.setProperty('background-image', 'url(' + source.src + ')');
    }
    lb.style.display = "block";
    lb.focus()
    event.stopPropagation()
}

function negmod(n, m) {
    return ((n % m) + m) % m;
}

function updateOnBackgroundChange() {
    const modalImage = gradioApp().getElementById("modalImage")
    if (modalImage && modalImage.offsetParent) {
        let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
        let currentButton = null
        allcurrentButtons.forEach(function(elem) {
            if (elem.parentElement.offsetParent) {
                currentButton = elem;
            }
        })

        if (modalImage.src != currentButton.children[0].src) {
            modalImage.src = currentButton.children[0].src;
            if (modalImage.style.display === 'none') {
                modal.style.setProperty('background-image', `url(${modalImage.src})`)
            }
        }
    }
}

function modalImageSwitch(offset) {
    var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
    var galleryButtons = []
    allgalleryButtons.forEach(function(elem) {
        if (elem.parentElement.offsetParent) {
            galleryButtons.push(elem);
        }
    })

    if (galleryButtons.length > 1) {
        var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
        var currentButton = null
        allcurrentButtons.forEach(function(elem) {
            if (elem.parentElement.offsetParent) {
                currentButton = elem;
            }
        })

        var result = -1
        galleryButtons.forEach(function(v, i) {
            if (v == currentButton) {
                result = i
            }
        })

        if (result != -1) {
            nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
            nextButton.click()
            const modalImage = gradioApp().getElementById("modalImage");
            const modal = gradioApp().getElementById("lightboxModal");
            modalImage.src = nextButton.children[0].src;
            if (modalImage.style.display === 'none') {
                modal.style.setProperty('background-image', `url(${modalImage.src})`)
            }
            setTimeout(function() {
                modal.focus()
            }, 10)
        }
    }
}

function modalNextImage(event) {
    modalImageSwitch(1)
    event.stopPropagation()
}

function modalPrevImage(event) {
    modalImageSwitch(-1)
    event.stopPropagation()
}

function modalKeyHandler(event) {
    switch (event.key) {
        case "ArrowLeft":
            modalPrevImage(event)
            break;
        case "ArrowRight":
            modalNextImage(event)
            break;
        case "Escape":
            closeModal();
            break;
    }
}

function showGalleryImage() {
    setTimeout(function() {
        fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')

        if (fullImg_preview != null) {
            fullImg_preview.forEach(function function_name(e) {
                if (e.dataset.modded)
                    return;
                e.dataset.modded = true;
                if(e && e.parentElement.tagName == 'DIV'){
                    e.style.cursor='pointer'
                    e.addEventListener('click', function (evt) {
                        if(!opts.js_modal_lightbox) return;
                        modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
                        showModal(evt)
                    }, true);
                }
            });
        }

    }, 100);
}

function modalZoomSet(modalImage, enable) {
    if (enable) {
        modalImage.classList.add('modalImageFullscreen');
    } else {
        modalImage.classList.remove('modalImageFullscreen');
    }
}

function modalZoomToggle(event) {
    modalImage = gradioApp().getElementById("modalImage");
    modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
    event.stopPropagation()
}

function modalTileImageToggle(event) {
    const modalImage = gradioApp().getElementById("modalImage");
    const modal = gradioApp().getElementById("lightboxModal");
    const isTiling = modalImage.style.display === 'none';
    if (isTiling) {
        modalImage.style.display = 'block';
        modal.style.setProperty('background-image', 'none')
    } else {
        modalImage.style.display = 'none';
        modal.style.setProperty('background-image', `url(${modalImage.src})`)
    }

    event.stopPropagation()
}

function galleryImageHandler(e) {
    if (e && e.parentElement.tagName == 'BUTTON') {
        e.onclick = showGalleryImage;
    }
}

onUiUpdate(function() {
    fullImg_preview = gradioApp().querySelectorAll('img.w-full')
    if (fullImg_preview != null) {
        fullImg_preview.forEach(galleryImageHandler);
    }
    updateOnBackgroundChange();
})

document.addEventListener("DOMContentLoaded", function() {
    const modalFragment = document.createDocumentFragment();
    const modal = document.createElement('div')
    modal.onclick = closeModal;
    modal.id = "lightboxModal";
    modal.tabIndex = 0
    modal.addEventListener('keydown', modalKeyHandler, true)

    const modalControls = document.createElement('div')
    modalControls.className = 'modalControls gradio-container';
    modal.append(modalControls);

    const modalZoom = document.createElement('span')
    modalZoom.className = 'modalZoom cursor';
    modalZoom.innerHTML = '⤡'
    modalZoom.addEventListener('click', modalZoomToggle, true)
    modalZoom.title = "Toggle zoomed view";
    modalControls.appendChild(modalZoom)

    const modalTileImage = document.createElement('span')
    modalTileImage.className = 'modalTileImage cursor';
    modalTileImage.innerHTML = '⊞'
    modalTileImage.addEventListener('click', modalTileImageToggle, true)
    modalTileImage.title = "Preview tiling";
    modalControls.appendChild(modalTileImage)

    const modalClose = document.createElement('span')
    modalClose.className = 'modalClose cursor';
    modalClose.innerHTML = '×'
    modalClose.onclick = closeModal;
    modalClose.title = "Close image viewer";
    modalControls.appendChild(modalClose)

    const modalImage = document.createElement('img')
    modalImage.id = 'modalImage';
    modalImage.onclick = closeModal;
    modalImage.tabIndex = 0
    modalImage.addEventListener('keydown', modalKeyHandler, true)
    modal.appendChild(modalImage)

    const modalPrev = document.createElement('a')
    modalPrev.className = 'modalPrev';
    modalPrev.innerHTML = '❮'
    modalPrev.tabIndex = 0
    modalPrev.addEventListener('click', modalPrevImage, true);
    modalPrev.addEventListener('keydown', modalKeyHandler, true)
    modal.appendChild(modalPrev)

    const modalNext = document.createElement('a')
    modalNext.className = 'modalNext';
    modalNext.innerHTML = '❯'
    modalNext.tabIndex = 0
    modalNext.addEventListener('click', modalNextImage, true);
    modalNext.addEventListener('keydown', modalKeyHandler, true)

    modal.appendChild(modalNext)


    gradioApp().getRootNode().appendChild(modal)

    document.body.appendChild(modalFragment);

});
javascript/notification.js
ADDED
@@ -0,0 +1,49 @@
// Monitors the gallery and sends a browser notification when the leading image is new.

let lastHeadImg = null;

notificationButton = null

onUiUpdate(function(){
    if(notificationButton == null){
        notificationButton = gradioApp().getElementById('request_notifications')

        if(notificationButton != null){
            notificationButton.addEventListener('click', function (evt) {
                Notification.requestPermission();
            },true);
        }
    }

    const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden');

    if (galleryPreviews == null) return;

    const headImg = galleryPreviews[0]?.src;

    if (headImg == null || headImg == lastHeadImg) return;

    lastHeadImg = headImg;

    // play notification sound if available
    gradioApp().querySelector('#audio_notification audio')?.play();

    if (document.hasFocus()) return;

    // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
    const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));

    const notification = new Notification(
        'Stable Diffusion',
        {
            body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
            icon: headImg,
            image: headImg,
        }
    );

    notification.onclick = function(_){
        parent.focus();
        this.close();
    };
});
javascript/progressbar.js
ADDED
@@ -0,0 +1,76 @@
// code related to showing and updating progressbar shown as the image is being made
global_progressbars = {}

function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
    var progressbar = gradioApp().getElementById(id_progressbar)
    var skip = id_skip ? gradioApp().getElementById(id_skip) : null
    var interrupt = gradioApp().getElementById(id_interrupt)

    if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
        if(progressbar.innerText){
            let newtitle = 'Stable Diffusion - ' + progressbar.innerText
            if(document.title != newtitle){
                document.title = newtitle;
            }
        }else{
            let newtitle = 'Stable Diffusion'
            if(document.title != newtitle){
                document.title = newtitle;
            }
        }
    }

    if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
        global_progressbars[id_progressbar] = progressbar

        var mutationObserver = new MutationObserver(function(m){
            preview = gradioApp().getElementById(id_preview)
            gallery = gradioApp().getElementById(id_gallery)

            if(preview != null && gallery != null){
                preview.style.width = gallery.clientWidth + "px"
                preview.style.height = gallery.clientHeight + "px"

                var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
                if(!progressDiv){
                    if (skip) {
                        skip.style.display = "none"
                    }
                    interrupt.style.display = "none"
                }
            }

            window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
        });
        mutationObserver.observe( progressbar, { childList:true, subtree:true })
    }
}

onUiUpdate(function(){
    check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
    check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
    check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})

function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
    btn = gradioApp().getElementById(id_part+"_check_progress");
    if(btn==null) return;

    btn.click();
    var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
    var skip = id_skip ? gradioApp().getElementById(id_skip) : null
    var interrupt = gradioApp().getElementById(id_interrupt)
    if(progressDiv && interrupt){
        if (skip) {
            skip.style.display = "block"
        }
        interrupt.style.display = "block"
    }
}

function requestProgress(id_part){
    btn = gradioApp().getElementById(id_part+"_check_progress_initial");
    if(btn==null) return;

    btn.click();
}
javascript/textualInversion.js
ADDED
@@ -0,0 +1,8 @@
function start_training_textual_inversion(){
    requestProgress('ti')
    gradioApp().querySelector('#ti_error').innerHTML=''

    return args_to_array(arguments)
}
javascript/ui.js
ADDED
@@ -0,0 +1,234 @@
// various functions for interaction with ui.py not large enough to warrant putting them in separate files

function selected_gallery_index(){
    var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
    var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')

    var result = -1
    buttons.forEach(function(v, i){ if(v==button) { result = i } })

    return result
}

function extract_image_from_gallery(gallery){
    if(gallery.length == 1){
        return gallery[0]
    }

    index = selected_gallery_index()

    if (index < 0 || index >= gallery.length){
        return [null]
    }

    return gallery[index];
}

function args_to_array(args){
    res = []
    for(var i=0;i<args.length;i++){
        res.push(args[i])
    }
    return res
}

function switch_to_txt2img(){
    gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();

    return args_to_array(arguments);
}

function switch_to_img2img_img2img(){
    gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
    gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();

    return args_to_array(arguments);
}

function switch_to_img2img_inpaint(){
    gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
    gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();

    return args_to_array(arguments);
}

function switch_to_extras(){
    gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();

    return args_to_array(arguments);
}

function extract_image_from_gallery_txt2img(gallery){
    switch_to_txt2img()
    return extract_image_from_gallery(gallery);
}

function extract_image_from_gallery_img2img(gallery){
    switch_to_img2img_img2img()
    return extract_image_from_gallery(gallery);
}

function extract_image_from_gallery_inpaint(gallery){
    switch_to_img2img_inpaint()
    return extract_image_from_gallery(gallery);
}

function extract_image_from_gallery_extras(gallery){
    switch_to_extras()
    return extract_image_from_gallery(gallery);
}

function get_tab_index(tabId){
    var res = 0

    gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
        if(button.className.indexOf('bg-white') != -1)
            res = i
    })

    return res
}

function create_tab_index_args(tabId, args){
    var res = []
    for(var i=0; i<args.length; i++){
        res.push(args[i])
    }

    res[0] = get_tab_index(tabId)

    return res
}

function get_extras_tab_index(){
    const [,,...args] = [...arguments]
    return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
}

function create_submit_args(args){
    res = []
    for(var i=0;i<args.length;i++){
        res.push(args[i])
    }

    // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
    // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
    // I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
    // If gradio at some point stops sending outputs, this may break something
    if(Array.isArray(res[res.length - 3])){
        res[res.length - 3] = null
    }

    return res
}

function submit(){
    requestProgress('txt2img')

    return create_submit_args(arguments)
}

function submit_img2img(){
    requestProgress('img2img')

    res = create_submit_args(arguments)

    res[0] = get_tab_index('mode_img2img')

    return res
}


function ask_for_style_name(_, prompt_text, negative_prompt_text) {
    name_ = prompt('Style name:')
    return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text]
}


opts = {}
function apply_settings(jsdata){
    console.log(jsdata)

    opts = JSON.parse(jsdata)

    return jsdata
}

onUiUpdate(function(){
    if(Object.keys(opts).length != 0) return;

    json_elem = gradioApp().getElementById('settings_json')
    if(json_elem == null) return;

    textarea = json_elem.querySelector('textarea')
    jsdata = textarea.value
    opts = JSON.parse(jsdata)

    Object.defineProperty(textarea, 'value', {
        set: function(newValue) {
            var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
            var oldValue = valueProp.get.call(textarea);
            valueProp.set.call(textarea, newValue);

            if (oldValue != newValue) {
                opts = JSON.parse(textarea.value)
            }
+
res = []
|
29 |
+
for(var i=0;i<args.length;i++){
|
30 |
+
res.push(args[i])
|
31 |
+
}
|
32 |
+
return res
|
33 |
+
}
|
34 |
+
|
35 |
+
function switch_to_txt2img(){
|
36 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
|
37 |
+
|
38 |
+
return args_to_array(arguments);
|
39 |
+
}
|
40 |
+
|
41 |
+
function switch_to_img2img_img2img(){
|
42 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
43 |
+
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
|
44 |
+
|
45 |
+
return args_to_array(arguments);
|
46 |
+
}
|
47 |
+
|
48 |
+
function switch_to_img2img_inpaint(){
|
49 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
50 |
+
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
|
51 |
+
|
52 |
+
return args_to_array(arguments);
|
53 |
+
}
|
54 |
+
|
55 |
+
function switch_to_extras(){
|
56 |
+
gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
|
57 |
+
|
58 |
+
return args_to_array(arguments);
|
59 |
+
}
|
60 |
+
|
61 |
+
function extract_image_from_gallery_txt2img(gallery){
|
62 |
+
switch_to_txt2img()
|
63 |
+
return extract_image_from_gallery(gallery);
|
64 |
+
}
|
65 |
+
|
66 |
+
function extract_image_from_gallery_img2img(gallery){
|
67 |
+
switch_to_img2img_img2img()
|
68 |
+
return extract_image_from_gallery(gallery);
|
69 |
+
}
|
70 |
+
|
71 |
+
function extract_image_from_gallery_inpaint(gallery){
|
72 |
+
switch_to_img2img_inpaint()
|
73 |
+
return extract_image_from_gallery(gallery);
|
74 |
+
}
|
75 |
+
|
76 |
+
function extract_image_from_gallery_extras(gallery){
|
77 |
+
switch_to_extras()
|
78 |
+
return extract_image_from_gallery(gallery);
|
79 |
+
}
|
80 |
+
|
81 |
+
function get_tab_index(tabId){
|
82 |
+
var res = 0
|
83 |
+
|
84 |
+
gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
|
85 |
+
if(button.className.indexOf('bg-white') != -1)
|
86 |
+
res = i
|
87 |
+
})
|
88 |
+
|
89 |
+
return res
|
90 |
+
}
|
91 |
+
|
92 |
+
function create_tab_index_args(tabId, args){
|
93 |
+
var res = []
|
94 |
+
for(var i=0; i<args.length; i++){
|
95 |
+
res.push(args[i])
|
96 |
+
}
|
97 |
+
|
98 |
+
res[0] = get_tab_index(tabId)
|
99 |
+
|
100 |
+
return res
|
101 |
+
}
|
102 |
+
|
103 |
+
function get_extras_tab_index(){
|
104 |
+
const [,,...args] = [...arguments]
|
105 |
+
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
|
106 |
+
}
|
107 |
+
|
108 |
+
function create_submit_args(args){
|
109 |
+
res = []
|
110 |
+
for(var i=0;i<args.length;i++){
|
111 |
+
res.push(args[i])
|
112 |
+
}
|
113 |
+
|
114 |
+
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
|
115 |
+
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
|
116 |
+
// I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
117 |
+
// If gradio at some point stops sending outputs, this may break something
|
118 |
+
if(Array.isArray(res[res.length - 3])){
|
119 |
+
res[res.length - 3] = null
|
120 |
+
}
|
121 |
+
|
122 |
+
return res
|
123 |
+
}
|
124 |
+
|
125 |
+
function submit(){
|
126 |
+
requestProgress('txt2img')
|
127 |
+
|
128 |
+
return create_submit_args(arguments)
|
129 |
+
}
|
130 |
+
|
131 |
+
function submit_img2img(){
|
132 |
+
requestProgress('img2img')
|
133 |
+
|
134 |
+
res = create_submit_args(arguments)
|
135 |
+
|
136 |
+
res[0] = get_tab_index('mode_img2img')
|
137 |
+
|
138 |
+
return res
|
139 |
+
}
|
140 |
+
|
141 |
+
|
142 |
+
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
143 |
+
name_ = prompt('Style name:')
|
144 |
+
return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text]
|
145 |
+
}
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
opts = {}
|
150 |
+
function apply_settings(jsdata){
|
151 |
+
console.log(jsdata)
|
152 |
+
|
153 |
+
opts = JSON.parse(jsdata)
|
154 |
+
|
155 |
+
return jsdata
|
156 |
+
}
|
157 |
+
|
158 |
+
onUiUpdate(function(){
|
159 |
+
if(Object.keys(opts).length != 0) return;
|
160 |
+
|
161 |
+
json_elem = gradioApp().getElementById('settings_json')
|
162 |
+
if(json_elem == null) return;
|
163 |
+
|
164 |
+
textarea = json_elem.querySelector('textarea')
|
165 |
+
jsdata = textarea.value
|
166 |
+
opts = JSON.parse(jsdata)
|
167 |
+
|
168 |
+
|
169 |
+
Object.defineProperty(textarea, 'value', {
|
170 |
+
set: function(newValue) {
|
171 |
+
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
172 |
+
var oldValue = valueProp.get.call(textarea);
|
173 |
+
valueProp.set.call(textarea, newValue);
|
174 |
+
|
175 |
+
if (oldValue != newValue) {
|
176 |
+
opts = JSON.parse(textarea.value)
|
177 |
+
}
|
178 |
+
},
|
179 |
+
get: function() {
|
180 |
+
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
181 |
+
return valueProp.get.call(textarea);
|
182 |
+
}
|
183 |
+
});
|
184 |
+
|
185 |
+
json_elem.parentElement.style.display="none"
|
186 |
+
|
187 |
+
if (!txt2img_textarea) {
|
188 |
+
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
189 |
+
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
190 |
+
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
|
191 |
+
}
|
192 |
+
if (!img2img_textarea) {
|
193 |
+
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
194 |
+
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
195 |
+
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
|
196 |
+
}
|
197 |
+
})
|
198 |
+
|
199 |
+
let txt2img_textarea, img2img_textarea = undefined;
|
200 |
+
let wait_time = 800
|
201 |
+
let token_timeout;
|
202 |
+
|
203 |
+
function update_txt2img_tokens(...args) {
|
204 |
+
update_token_counter("txt2img_token_button")
|
205 |
+
if (args.length == 2)
|
206 |
+
return args[0]
|
207 |
+
return args;
|
208 |
+
}
|
209 |
+
|
210 |
+
function update_img2img_tokens(...args) {
|
211 |
+
update_token_counter("img2img_token_button")
|
212 |
+
if (args.length == 2)
|
213 |
+
return args[0]
|
214 |
+
return args;
|
215 |
+
}
|
216 |
+
|
217 |
+
function update_token_counter(button_id) {
|
218 |
+
if (token_timeout)
|
219 |
+
clearTimeout(token_timeout);
|
220 |
+
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
221 |
+
}
|
222 |
+
|
223 |
+
function submit_prompt(event, generate_button_id) {
|
224 |
+
if (event.altKey && event.keyCode === 13) {
|
225 |
+
event.preventDefault();
|
226 |
+
gradioApp().getElementById(generate_button_id).click();
|
227 |
+
return;
|
228 |
+
}
|
229 |
+
}
|
230 |
+
|
231 |
+
function restart_reload(){
|
232 |
+
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
233 |
+
setTimeout(function(){location.reload()},2000)
|
234 |
+
}
|
launch.py
ADDED
@@ -0,0 +1,169 @@
# this script installs necessary requirements and launches main program in webui.py
import subprocess
import os
import sys
import importlib.util
import shlex
import platform

dir_repos = "repositories"
python = sys.executable
git = os.environ.get('GIT', "git")


def extract_arg(args, name):
    return [x for x in args if x != name], name in args


def run(command, desc=None, errdesc=None):
    if desc is not None:
        print(desc)

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)

    if result.returncode != 0:

        message = f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
"""
        raise RuntimeError(message)

    return result.stdout.decode(encoding="utf8", errors="ignore")


def check_run(command):
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    return result.returncode == 0


def is_installed(package):
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None


def repo_dir(name):
    return os.path.join(dir_repos, name)


def run_python(code, desc=None, errdesc=None):
    return run(f'"{python}" -c "{code}"', desc, errdesc)


def run_pip(args, desc=None):
    return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")


def check_run_python(code):
    return check_run(f'"{python}" -c "{code}"')


def git_clone(url, dir, name, commithash=None):
    # TODO clone into temporary dir and move if successful

    if os.path.exists(dir):
        if commithash is None:
            return

        current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
        if current_hash == commithash:
            return

        run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
        run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
        return

    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

    if commithash is not None:
        run(f'"{git}" -C {dir} checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")


def prepare_enviroment():
    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
    commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
    clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")

    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
    taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

    args = shlex.split(commandline_args)

    args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
    xformers = '--xformers' in args
    deepdanbooru = '--deepdanbooru' in args
    ngrok = '--ngrok' in args

    try:
        commit = run(f"{git} rev-parse HEAD").strip()
    except Exception:
        commit = "<none>"

    print(f"Python {sys.version}")
    print(f"Commit hash: {commit}")

    if not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")

    if not skip_torch_cuda_test:
        run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

    if not is_installed("gfpgan"):
        run_pip(f"install {gfpgan_package}", "gfpgan")

    if not is_installed("clip"):
        run_pip(f"install {clip_package}", "clip")

    if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
        if platform.system() == "Windows":
            run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
        elif platform.system() == "Linux":
            run_pip("install xformers", "xformers")

    if not is_installed("deepdanbooru") and deepdanbooru:
        run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")

    if not is_installed("pyngrok") and ngrok:
        run_pip("install pyngrok", "ngrok")

    os.makedirs(dir_repos, exist_ok=True)

    git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
    git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
    git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
    git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)

    if not is_installed("lpips"):
        run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

    run_pip(f"install -r {requirements_file}", "requirements for Web UI")

    sys.argv += args

    if "--exit" in args:
        print("Exiting because of --exit argument")
        exit(0)


def start_webui():
    print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
    import webui
    webui.webui()


if __name__ == "__main__":
    prepare_enviroment()
    start_webui()
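
launch.py takes its configuration from environment variables and COMMANDLINE_ARGS rather than its own CLI flags. A minimal sketch of invoking it that way follows; it is not part of the commit, the flag values are only illustrative, and it assumes launch.py sits in the current directory:

import os
import subprocess
import sys

env = dict(os.environ)
# flags below are parsed by prepare_enviroment(); --exit makes it stop after installing
env["COMMANDLINE_ARGS"] = "--skip-torch-cuda-test --exit"
env["REQS_FILE"] = "requirements_versions.txt"  # same default the script itself uses

subprocess.run([sys.executable, "launch.py"], env=env, check=True)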
modules/artists.py
ADDED
@@ -0,0 +1,25 @@
import os.path
import csv
from collections import namedtuple

Artist = namedtuple("Artist", ['name', 'weight', 'category'])


class ArtistsDatabase:
    def __init__(self, filename):
        self.cats = set()
        self.artists = []

        if not os.path.exists(filename):
            return

        with open(filename, "r", newline='', encoding="utf8") as file:
            reader = csv.DictReader(file)

            for row in reader:
                artist = Artist(row["artist"], float(row["score"]), row["category"])
                self.artists.append(artist)
                self.cats.add(artist.category)

    def categories(self):
        return sorted(self.cats)
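
ArtistsDatabase reads a CSV whose header must provide artist, score and category columns, matching the DictReader keys used above. A minimal sketch with a hypothetical file name, assuming the webui modules are importable:

import csv

from modules.artists import ArtistsDatabase

# write a tiny CSV in the expected shape (file name and values are made up)
with open("artists_example.csv", "w", newline='', encoding="utf8") as file:
    writer = csv.DictWriter(file, fieldnames=["artist", "score", "category"])
    writer.writeheader()
    writer.writerow({"artist": "Example Artist", "score": "0.9", "category": "digital"})

db = ArtistsDatabase("artists_example.csv")
print(db.categories())        # ['digital']
print(db.artists[0].weight)   # 0.9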
modules/bsrgan_model.py
ADDED
@@ -0,0 +1,76 @@
import os.path
import sys
import traceback

import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url

import modules.upscaler
from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet


class UpscalerBSRGAN(modules.upscaler.Upscaler):
    def __init__(self, dirname):
        self.name = "BSRGAN"
        self.model_name = "BSRGAN 4x"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        scalers = []
        if len(model_paths) == 0:
            scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
            scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
            try:
                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
                scalers.append(scaler_data)
            except Exception:
                print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
        self.scalers = scalers

    def do_upscale(self, img: PIL.Image, selected_file):
        torch.cuda.empty_cache()
        model = self.load_model(selected_file)
        if model is None:
            return img
        model.to(devices.device_bsrgan)
        torch.cuda.empty_cache()
        img = np.array(img)
        img = img[:, :, ::-1]
        img = np.moveaxis(img, 2, 0) / 255
        img = torch.from_numpy(img).float()
        img = img.unsqueeze(0).to(devices.device_bsrgan)
        with torch.no_grad():
            output = model(img)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255. * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]
        torch.cuda.empty_cache()
        return PIL.Image.fromarray(output, 'RGB')

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
                                          progress=True)
        else:
            filename = path
        if not os.path.exists(filename) or filename is None:
            print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
            return None
        model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)  # define network
        model.load_state_dict(torch.load(filename), strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        return model
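
The numpy/torch juggling in do_upscale follows a fixed convention: PIL gives RGB HWC uint8, the network expects BGR CHW float in [0, 1], and the output is converted back the same way. A standalone sketch of just that round trip, not part of the commit and with no model involved:

import numpy as np
import torch
import PIL.Image

img = PIL.Image.new("RGB", (64, 64), "gray")         # stand-in for the input image
arr = np.array(img)[:, :, ::-1]                      # RGB -> BGR
arr = np.moveaxis(arr, 2, 0) / 255                   # HWC -> CHW, scale to [0, 1]
tensor = torch.from_numpy(arr).float().unsqueeze(0)  # add the batch dimension

# inverse path, as used after the model call
out = tensor.squeeze().clamp_(0, 1).numpy()
out = (255. * np.moveaxis(out, 0, 2)).astype(np.uint8)[:, :, ::-1]
restored = PIL.Image.fromarray(out, 'RGB')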
modules/bsrgan_model_arch.py
ADDED
@@ -0,0 +1,102 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.sf = sf

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.sf==4:
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if self.sf==4:
            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
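
RRDBNet is a plain 4x super-resolution network: when sf is 4 it applies two nearest-neighbour upsample stages. A quick shape check with random weights, not part of the commit and slow on CPU because nb=23, assuming the webui package is importable:

import torch

from modules.bsrgan_model_arch import RRDBNet

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4).eval()
with torch.no_grad():
    y = net(torch.rand(1, 3, 32, 32))
print(y.shape)  # torch.Size([1, 3, 128, 128]), i.e. 4x the input resolution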
modules/codeformer/codeformer_arch.py
ADDED
@@ -0,0 +1,278 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py

import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List

from modules.codeformer.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have the similar color and illuminations
    as those in the degradate features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degradate features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        # self attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt

class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out


@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                codebook_size=1024, latent_size=256,
                connect_list=['32', '64', '128', '256'],
                fix_modules=['quantize','generator']):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
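
calc_mean_std and adaptive_instance_normalization above implement standard AdaIN: per-channel statistics of one 4D feature map are transplanted onto another, which is how CodeFormer blends the quantized features with the degraded input when adain=True. A small sketch, not part of the commit, assuming basicsr and the webui modules are importable:

import torch

from modules.codeformer.codeformer_arch import calc_mean_std, adaptive_instance_normalization

content = torch.randn(2, 8, 16, 16)
style = torch.randn(2, 8, 16, 16) * 3 + 1

mean, std = calc_mean_std(style)
print(mean.shape, std.shape)   # both torch.Size([2, 8, 1, 1])

restyled = adaptive_instance_normalization(content, style)
print(restyled.shape)          # torch.Size([2, 8, 16, 16]), content rescaled to the style statistics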
modules/codeformer/vqgan_arch.py
ADDED
@@ -0,0 +1,437 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py

'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


@torch.jit.script
def swish(x):
    return x*torch.sigmoid(x)


# Define VQVAE classes
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        min_encoding_scores = torch.exp(-min_encoding_scores/10)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "min_encoding_scores": min_encoding_scores,
            "mean_distance": mean_distance
            }

    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1,1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
        }


class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h*w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x+h_


class Encoder(nn.Module):
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,)+tuple(ch_mult)

        blocks = []
        # initial convultion
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)


    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta #0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError(f'Wrong params!')


    def forward(self, x):
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats



# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
            else:
                raise ValueError(f'Wrong params!')

    def forward(self, x):
        return self.main(x)
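
VectorQuantizer above does a nearest-neighbour codebook lookup with a straight-through estimator, and get_codebook_feat is the decode-only path CodeFormer uses at inference. A minimal sketch of the forward call, not part of the commit and assuming basicsr is installed so the module imports:

import torch

from modules.codeformer.vqgan_arch import VectorQuantizer

quant = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
z = torch.randn(1, 256, 16, 16)               # (b, c, h, w) latent from the encoder
z_q, loss, stats = quant(z)

print(z_q.shape)                              # torch.Size([1, 256, 16, 16]), same shape as z
print(loss.item())                            # embedding + commitment loss
print(stats["min_encoding_indices"].shape)    # torch.Size([256, 1]), one codebook index per position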
modules/codeformer_model.py
ADDED
@@ -0,0 +1,140 @@
import os
import sys
import traceback

import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
from modules.paths import script_path, models_path

# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

have_codeformer = False
codeformer = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return

    try:
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer
        from basicsr.utils.download_util import load_file_from_url
        from basicsr.utils import imwrite, img2tensor, tensor2img
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
        from modules.shared import cmd_opts

        net_class = CodeFormer

        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
            def name(self):
                return "CodeFormer"

            def __init__(self, dirname):
                self.net = None
                self.face_helper = None
                self.cmd_dir = dirname

            def create_models(self):

                if self.net is not None and self.face_helper is not None:
                    self.net.to(devices.device_codeformer)
                    return self.net, self.face_helper
                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
                if len(model_paths) != 0:
                    ckpt_path = model_paths[0]
                else:
                    print("Unable to load codeformer model.")
                    return None, None
                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                checkpoint = torch.load(ckpt_path)['params_ema']
                net.load_state_dict(checkpoint)
                net.eval()

                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)

                self.net = net
                self.face_helper = face_helper

                return net, face_helper

            def send_model_to(self, device):
                self.net.to(device)
                self.face_helper.face_det.to(device)
                self.face_helper.face_parse.to(device)

            def restore(self, np_image, w=None):
                np_image = np_image[:, :, ::-1]

                original_resolution = np_image.shape[0:2]

                self.create_models()
                if self.net is None or self.face_helper is None:
                    return np_image

                self.send_model_to(devices.device_codeformer)

                self.face_helper.clean_all()
                self.face_helper.read_image(np_image)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()

                for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

                    try:
                        with torch.no_grad():
                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                        del output
                        torch.cuda.empty_cache()
                    except Exception as error:
                        print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    restored_face = restored_face.astype('uint8')
                    self.face_helper.add_restored_face(restored_face)

                self.face_helper.get_inverse_affine(None)

                restored_img = self.face_helper.paste_faces_to_input_image()
                restored_img = restored_img[:, :, ::-1]

                if original_resolution != restored_img.shape[0:2]:
                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)

                self.face_helper.clean_all()

                if shared.opts.face_restoration_unload:
                    self.send_model_to(devices.cpu)

                return restored_img
|
128 |
+
|
129 |
+
global have_codeformer
|
130 |
+
have_codeformer = True
|
131 |
+
|
132 |
+
global codeformer
|
133 |
+
codeformer = FaceRestorerCodeFormer(dirname)
|
134 |
+
shared.face_restorers.append(codeformer)
|
135 |
+
|
136 |
+
except Exception:
|
137 |
+
print("Error setting up CodeFormer:", file=sys.stderr)
|
138 |
+
print(traceback.format_exc(), file=sys.stderr)
|
139 |
+
|
140 |
+
# sys.path = stored_sys_path
|
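Note (sketch, not part of the commit): a minimal example of how this module is typically driven; it assumes the launcher has already initialized shared.opts and the models directory, and that cmd_opts.codeformer_models_path points at the CodeFormer model folder.

    import numpy as np
    from PIL import Image
    import modules.codeformer_model as codeformer_model
    from modules.shared import cmd_opts

    codeformer_model.setup_model(cmd_opts.codeformer_models_path)  # registers the restorer in shared.face_restorers
    img = np.array(Image.open("face.png").convert("RGB"), dtype=np.uint8)
    restored = codeformer_model.codeformer.restore(img, w=0.5)     # w trades fidelity against restoration strength
    Image.fromarray(restored).save("face_restored.png")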
modules/deepbooru.py
ADDED
@@ -0,0 +1,173 @@
import os.path
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import time
import re

re_special = re.compile(r'([\\()])')

def get_deepbooru_tags(pil_image):
    """
    This method is for running only one image at a time for simple use. Used to the img2img interrogate.
    """
    from modules import shared # prevents circular reference

    try:
        create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
        return get_tags_from_process(pil_image)
    finally:
        release_process()


OPT_INCLUDE_RANKS = "include_ranks"
def create_deepbooru_opts():
    from modules import shared

    return {
        "use_spaces": shared.opts.deepbooru_use_spaces,
        "use_escape": shared.opts.deepbooru_escape,
        "alpha_sort": shared.opts.deepbooru_sort_alpha,
        OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
    }


def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
    model, tags = get_deepbooru_tags_model()
    while True: # while process is running, keep monitoring queue for new image
        pil_image = queue.get()
        if pil_image == "QUIT":
            break
        else:
            deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)


def create_deepbooru_process(threshold, deepbooru_opts):
    """
    Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
    to be processed in a row without reloading the model or creating a new process. To return the data, a shared
    dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
    to the dictionary and the method adding the image to the queue should wait for this value to be updated with
    the tags.
    """
    from modules import shared # prevents circular reference
    shared.deepbooru_process_manager = multiprocessing.Manager()
    shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
    shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
    shared.deepbooru_process_return["value"] = -1
    shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
    shared.deepbooru_process.start()


def get_tags_from_process(image):
    from modules import shared

    shared.deepbooru_process_return["value"] = -1
    shared.deepbooru_process_queue.put(image)
    while shared.deepbooru_process_return["value"] == -1:
        time.sleep(0.2)
    caption = shared.deepbooru_process_return["value"]
    shared.deepbooru_process_return["value"] = -1

    return caption


def release_process():
    """
    Stops the deepbooru process to return used memory
    """
    from modules import shared # prevents circular reference
    shared.deepbooru_process_queue.put("QUIT")
    shared.deepbooru_process.join()
    shared.deepbooru_process_queue = None
    shared.deepbooru_process = None
    shared.deepbooru_process_return = None
    shared.deepbooru_process_manager = None

def get_deepbooru_tags_model():
    import deepdanbooru as dd
    import tensorflow as tf
    import numpy as np
    this_folder = os.path.dirname(__file__)
    model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
    if not os.path.exists(os.path.join(model_path, 'project.json')):
        # there is no point importing these every time
        import zipfile
        from basicsr.utils.download_util import load_file_from_url
        load_file_from_url(
            r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
            model_path)
        with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
            zip_ref.extractall(model_path)
        os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))

    tags = dd.project.load_tags_from_project(model_path)
    model = dd.project.load_model_from_project(
        model_path, compile_model=True
    )
    return model, tags


def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
    import deepdanbooru as dd
    import tensorflow as tf
    import numpy as np

    alpha_sort = deepbooru_opts['alpha_sort']
    use_spaces = deepbooru_opts['use_spaces']
    use_escape = deepbooru_opts['use_escape']
    include_ranks = deepbooru_opts['include_ranks']

    width = model.input_shape[2]
    height = model.input_shape[1]
    image = np.array(pil_image)
    image = tf.image.resize(
        image,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    image = image.numpy()  # EagerTensor to np.array
    image = dd.image.transform_and_pad_image(image, width, height)
    image = image / 255.0
    image_shape = image.shape
    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))

    y = model.predict(image)[0]

    result_dict = {}

    for i, tag in enumerate(tags):
        result_dict[tag] = y[i]

    unsorted_tags_in_theshold = []
    result_tags_print = []
    for tag in tags:
        if result_dict[tag] >= threshold:
            if tag.startswith("rating:"):
                continue
            unsorted_tags_in_theshold.append((result_dict[tag], tag))
            result_tags_print.append(f'{result_dict[tag]} {tag}')

    # sort tags
    result_tags_out = []
    sort_ndx = 0
    if alpha_sort:
        sort_ndx = 1

    # sort by reverse by likelihood and normal for alpha, and format tag text as requested
    unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
    for weight, tag in unsorted_tags_in_theshold:
        # note: tag_outformat will still have a colon if include_ranks is True
        tag_outformat = tag.replace(':', ' ')
        if use_spaces:
            tag_outformat = tag_outformat.replace('_', ' ')
        if use_escape:
            tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
        if include_ranks:
            tag_outformat = f"({tag_outformat}:{weight:.3f})"

        result_tags_out.append(tag_outformat)

    print('\n'.join(sorted(result_tags_print, reverse=True)))

    return ', '.join(result_tags_out)
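Note (sketch, not part of the commit): the intended batch-tagging call pattern, which keeps the DeepDanbooru model loaded in the worker process between images instead of reloading it for every file; names mirror the functions above.

    from PIL import Image
    from modules import deepbooru

    opts = deepbooru.create_deepbooru_opts()
    deepbooru.create_deepbooru_process(0.5, opts)          # threshold 0.5, spawn the worker once
    try:
        for path in ["a.png", "b.png"]:
            tags = deepbooru.get_tags_from_process(Image.open(path).convert("RGB"))
            print(path, tags)
    finally:
        deepbooru.release_process()                        # shut the worker down and free the model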
modules/devices.py
ADDED
@@ -0,0 +1,72 @@
import contextlib

import torch

from modules import errors

# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)

cpu = torch.device("cpu")


def get_optimal_device():
    if torch.cuda.is_available():
        return torch.device("cuda")

    if has_mps:
        return torch.device("mps")

    return cpu


def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def enable_tf32():
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


errors.run(enable_tf32, "Enabling TF32")

device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
dtype_vae = torch.float16

def randn(seed, shape):
    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
    if device.type == 'mps':
        generator = torch.Generator(device=cpu)
        generator.manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
        return noise

    torch.manual_seed(seed)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
    if device.type == 'mps':
        generator = torch.Generator(device=cpu)
        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
        return noise

    return torch.randn(shape, device=device)


def autocast(disable=False):
    from modules import shared

    if disable:
        return contextlib.nullcontext()

    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")
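Note (sketch, not part of the commit): randn takes an explicit seed so that sampling is reproducible; on Apple's MPS backend it routes generation through a CPU generator, so the same seed yields the same latent noise regardless of device.

    from modules import devices

    noise_a = devices.randn(1234, (1, 4, 64, 64))
    noise_b = devices.randn(1234, (1, 4, 64, 64))
    assert (noise_a == noise_b).all()                       # same seed -> identical noise tensor
    extra = devices.randn_without_seed((1, 4, 64, 64))      # continues the current RNG stream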
modules/errors.py
ADDED
@@ -0,0 +1,10 @@
import sys
import traceback


def run(code, task):
    try:
        code()
    except Exception as e:
        print(f"{task}: {type(e).__name__}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
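Note (sketch, not part of the commit): run is a tiny guard for optional setup steps executed at import time, as in errors.run(enable_tf32, "Enabling TF32") in modules/devices.py above. A hypothetical call:

    from modules import errors

    def risky_setup():
        raise RuntimeError("no GPU found")

    errors.run(risky_setup, "Setting up acceleration")   # prints the task name and traceback to stderr, never raises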
modules/esrgan_model.py
ADDED
@@ -0,0 +1,158 @@
import os

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts


def fix_model_layers(crt_model, pretrained_net):
    # this code is adapted from https://github.com/xinntao/ESRGAN
    if 'conv_first.weight' in pretrained_net:
        return pretrained_net

    if 'model.0.weight' not in pretrained_net:
        is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
        if is_realesrgan:
            raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
        else:
            raise Exception("The file is not a ESRGAN model.")

    crt_net = crt_model.state_dict()
    load_net_clean = {}
    for k, v in pretrained_net.items():
        if k.startswith('module.'):
            load_net_clean[k[7:]] = v
        else:
            load_net_clean[k] = v
    pretrained_net = load_net_clean

    tbd = []
    for k, v in crt_net.items():
        tbd.append(k)

    # directly copy
    for k, v in crt_net.items():
        if k in pretrained_net and pretrained_net[k].size() == v.size():
            crt_net[k] = pretrained_net[k]
            tbd.remove(k)

    crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
    crt_net['conv_first.bias'] = pretrained_net['model.0.bias']

    for k in tbd.copy():
        if 'RDB' in k:
            ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
            if '.weight' in k:
                ori_k = ori_k.replace('.weight', '.0.weight')
            elif '.bias' in k:
                ori_k = ori_k.replace('.bias', '.0.bias')
            crt_net[k] = pretrained_net[ori_k]
            tbd.remove(k)

    crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
    crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
    crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
    crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
    crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
    crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
    crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
    crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
    crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
    crt_net['conv_last.bias'] = pretrained_net['model.10.bias']

    return crt_net

class UpscalerESRGAN(Upscaler):
    def __init__(self, dirname):
        self.name = "ESRGAN"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
        self.model_name = "ESRGAN_4x"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        scalers = []
        if len(model_paths) == 0:
            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
            scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)

            scaler_data = UpscalerData(name, file, self, 4)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        model = self.load_model(selected_model)
        if model is None:
            return img
        model.to(devices.device_esrgan)
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                          file_name="%s.pth" % self.model_name,
                                          progress=True)
        else:
            filename = path
        if not os.path.exists(filename) or filename is None:
            print("Unable to load %s from %s" % (self.model_path, filename))
            return None

        pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
        crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)

        pretrained_net = fix_model_layers(crt_model, pretrained_net)
        crt_model.load_state_dict(pretrained_net)
        crt_model.eval()

        return crt_model


def upscale_without_tiling(model, img):
    img = np.array(img)
    img = img[:, :, ::-1]
    img = np.moveaxis(img, 2, 0) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_esrgan)
    with torch.no_grad():
        output = model(img)
    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = 255. * np.moveaxis(output, 0, 2)
    output = output.astype(np.uint8)
    output = output[:, :, ::-1]
    return Image.fromarray(output, 'RGB')


def esrgan_upscale(model, img):
    if opts.ESRGAN_tile == 0:
        return upscale_without_tiling(model, img)

    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
    newtiles = []
    scale_factor = 1

    for y, h, row in grid.tiles:
        newrow = []
        for tiledata in row:
            x, w, tile = tiledata

            output = upscale_without_tiling(model, tile)
            scale_factor = output.width // tile.width

            newrow.append([x * scale_factor, w * scale_factor, output])
        newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = images.combine_grid(newgrid)
    return output
modules/esrgan_model_arch.py
ADDED
@@ -0,0 +1,80 @@
# this file is taken from https://github.com/xinntao/ESRGAN

import functools
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
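Note (sketch, not part of the commit): with two nearest-neighbour x2 upsampling stages, RRDBNet is a fixed 4x upscaler, which is why UpscalerESRGAN above registers its models with scale 4. A quick shape check:

    import torch
    import modules.esrgan_model_arch as arch

    net = arch.RRDBNet(3, 3, 64, 23, gc=32).eval()
    with torch.no_grad():
        out = net(torch.zeros(1, 3, 64, 64))
    print(out.shape)   # torch.Size([1, 3, 256, 256]) -- 4x in each spatial dimension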
modules/extras.py
ADDED
@@ -0,0 +1,222 @@
import math
import os

import numpy as np
from PIL import Image

import torch
import tqdm

from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
import piexif
import piexif.helper
import gradio as gr


cached_images = {}


def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
    devices.torch_gc()

    imageArr = []
    # Also keep track of original file names
    imageNameArr = []

    if extras_mode == 1:
        #convert file to pillow image
        for img in image_folder:
            image = Image.open(img)
            imageArr.append(image)
            imageNameArr.append(os.path.splitext(img.orig_name)[0])
    else:
        imageArr.append(image)
        imageNameArr.append(None)

    outpath = opts.outdir_samples or opts.outdir_extras_samples

    outputs = []
    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
        info = ""

        if gfpgan_visibility > 0:
            restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
            res = Image.fromarray(restored_img)

            if gfpgan_visibility < 1.0:
                res = Image.blend(image, res, gfpgan_visibility)

            info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
            image = res

        if codeformer_visibility > 0:
            restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
            res = Image.fromarray(restored_img)

            if codeformer_visibility < 1.0:
                res = Image.blend(image, res, codeformer_visibility)

            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
            image = res

        if resize_mode == 1:
            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
            crop_info = " (crop)" if upscaling_crop else ""
            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"

        if upscaling_resize != 1.0:
            def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
                small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
                pixels = tuple(np.array(small).flatten().tolist())
                key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels

                c = cached_images.get(key)
                if c is None:
                    upscaler = shared.sd_upscalers[scaler_index]
                    c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
                    if mode == 1 and crop:
                        cropped = Image.new("RGB", (resize_w, resize_h))
                        cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
                        c = cropped
                    cached_images[key] = c

                return c

            info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)

            if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
                res = Image.blend(res, res2, extras_upscaler_2_visibility)

            image = res

        while len(cached_images) > 2:
            del cached_images[next(iter(cached_images.keys()))]

        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
                          forced_filename=image_name if opts.use_original_name_batch else None)

        if opts.enable_pnginfo:
            image.info = existing_pnginfo
            image.info["extras"] = info

        outputs.append(image)

    devices.torch_gc()

    return outputs, plaintext_to_html(info), ''


def run_pnginfo(image):
    if image is None:
        return '', '', ''

    items = image.info
    geninfo = ''

    if "exif" in image.info:
        exif = piexif.load(image.info["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        items['exif comment'] = exif_comment
        geninfo = exif_comment

        for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                      'loop', 'background', 'timestamp', 'duration']:
            items.pop(field, None)

    geninfo = items.get('parameters', geninfo)

    info = ''
    for key, text in items.items():
        info += f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()+"\n"

    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}<p></div>"

    return '', geninfo, info


def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
    # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    def sigmoid(theta0, theta1, alpha):
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    def inv_sigmoid(theta0, theta1, alpha):
        import math
        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
        return theta0 + ((theta1 - theta0) * alpha)

    primary_model_info = sd_models.checkpoints_list[primary_model_name]
    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]

    print(f"Loading {primary_model_info.filename}...")
    primary_model = torch.load(primary_model_info.filename, map_location='cpu')

    print(f"Loading {secondary_model_info.filename}...")
    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')

    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)

    theta_funcs = {
        "Weighted Sum": weighted_sum,
        "Sigmoid": sigmoid,
        "Inverse Sigmoid": inv_sigmoid,
    }
    theta_func = theta_funcs[interp_method]

    print(f"Merging...")
    for key in tqdm.tqdm(theta_0.keys()):
        if 'model' in key and key in theta_1:
            theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount))  # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
            if save_as_half:
                theta_0[key] = theta_0[key].half()

    for key in theta_1.keys():
        if 'model' in key and key not in theta_0:
            theta_0[key] = theta_1[key]
            if save_as_half:
                theta_0[key] = theta_0[key].half()

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
    filename = filename if custom_name == '' else (custom_name + '.ckpt')
    output_modelname = os.path.join(ckpt_dir, filename)

    print(f"Saving to {output_modelname}...")
    torch.save(primary_model, output_modelname)

    sd_models.list_models()

    print(f"Checkpoint saved.")
    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
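Note (sketch, not part of the commit): the three merge modes in run_modelmerger differ only in how they remap alpha before a linear blend, and the function receives 1.0 - interp_amount. A worked scalar example with interp_amount = 0.3:

    import math

    theta0, theta1, interp_amount = 10.0, 20.0, 0.3
    alpha = 1.0 - interp_amount                                   # 0.7, as passed by run_modelmerger
    weighted = (1 - alpha) * theta0 + alpha * theta1              # 17.0
    smooth_a = alpha * alpha * (3 - 2 * alpha)                    # smoothstep(0.7) = 0.784
    sigmoid = theta0 + (theta1 - theta0) * smooth_a               # 17.84
    inv_a = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)    # inverse smoothstep of 0.7
    inv_sig = theta0 + (theta1 - theta0) * inv_a
    print(weighted, sigmoid, inv_sig)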
modules/face_restoration.py
ADDED
@@ -0,0 +1,19 @@
from modules import shared


class FaceRestoration:
    def name(self):
        return "None"

    def restore(self, np_image):
        return np_image


def restore_faces(np_image):
    face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
    if len(face_restorers) == 0:
        return np_image

    face_restorer = face_restorers[0]

    return face_restorer.restore(np_image)
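Note (sketch, not part of the commit): a new backend plugs in by subclassing FaceRestoration and appending an instance to shared.face_restorers; the GFPGAN and CodeFormer modules below follow exactly this pattern. A hypothetical no-op backend:

    from modules import face_restoration, shared

    class FaceRestorerNoop(face_restoration.FaceRestoration):
        def name(self):
            return "Noop"                      # name shown in the face restoration setting

        def restore(self, np_image):
            return np_image                    # a real backend would return the restored image here

    shared.face_restorers.append(FaceRestorerNoop())
    # face_restoration.restore_faces(np_image) now dispatches based on shared.opts.face_restoration_model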
modules/generation_parameters_copypaste.py
ADDED
@@ -0,0 +1,101 @@
import os
import re
import gradio as gr
from modules.shared import script_path
from modules import shared

re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())


def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
```

    returns a dict with field values
    """

    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    if not re_params.match(lastline):
        lines.append(lastline)
        lastline = ''

    for i, line in enumerate(lines):
        line = line.strip()
        if line.startswith("Negative prompt:"):
            done_with_prompt = True
            line = line[16:].strip()

        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    if len(prompt) > 0:
        res["Prompt"] = prompt

    if len(negative_prompt) > 0:
        res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        m = re_imagesize.match(v)
        if m is not None:
            res[k+"-1"] = m.group(1)
            res[k+"-2"] = m.group(2)
        else:
            res[k] = v

    return res


def connect_paste(button, paste_fields, input_comp, js=None):
    def paste_func(prompt):
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            filename = os.path.join(script_path, "params.txt")
            if os.path.exists(filename):
                with open(filename, "r", encoding="utf8") as file:
                    prompt = file.read()

        params = parse_generation_parameters(prompt)
        res = []

        for output, key in paste_fields:
            if callable(key):
                v = key(params)
            else:
                v = params.get(key, None)

            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
            else:
                try:
                    valtype = type(output.value)
                    val = valtype(v)
                    res.append(gr.update(value=val))
                except Exception:
                    res.append(gr.update())

        return res

    button.click(
        fn=paste_func,
        _js=js,
        inputs=[input_comp],
        outputs=[x[0] for x in paste_fields],
    )
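Note (sketch, not part of the commit): what parse_generation_parameters returns for a typical infotext block; a Size entry such as 512x768 is split into two keys by re_imagesize.

    from modules.generation_parameters_copypaste import parse_generation_parameters

    text = (
        "a watercolor fox in a forest\n"
        "Negative prompt: blurry, low quality\n"
        "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x768"
    )
    params = parse_generation_parameters(text)
    # {'Prompt': 'a watercolor fox in a forest', 'Negative prompt': 'blurry, low quality',
    #  'Steps': '20', 'Sampler': 'Euler a', 'CFG scale': '7', 'Seed': '42',
    #  'Size-1': '512', 'Size-2': '768'}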
modules/gfpgan_model.py
ADDED
@@ -0,0 +1,115 @@
import os
import sys
import traceback

import facexlib
import gfpgan

import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path

model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None


def gfpgann():
    global loaded_gfpgan_model
    global model_path
    if loaded_gfpgan_model is not None:
        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
        return loaded_gfpgan_model

    if gfpgan_constructor is None:
        return None

    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
    if len(models) == 1 and "http" in models[0]:
        model_file = models[0]
    elif len(models) != 0:
        latest_file = max(models, key=os.path.getctime)
        model_file = latest_file
    else:
        print("Unable to load gfpgan model!")
        return None
    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    loaded_gfpgan_model = model

    return model


def send_model_to(model, device):
    model.gfpgan.to(device)
    model.face_helper.face_det.to(device)
    model.face_helper.face_parse.to(device)


def gfpgan_fix_faces(np_image):
    model = gfpgann()
    if model is None:
        return np_image

    send_model_to(model, devices.device_gfpgan)

    np_image_bgr = np_image[:, :, ::-1]
    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    np_image = gfpgan_output_bgr[:, :, ::-1]

    model.face_helper.clean_all()

    if shared.opts.face_restoration_unload:
        send_model_to(model, devices.cpu)

    return np_image


gfpgan_constructor = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    try:
        from gfpgan import GFPGANer
        from facexlib import detection, parsing
        global user_path
        global have_gfpgan
        global gfpgan_constructor

        load_file_from_url_orig = gfpgan.utils.load_file_from_url
        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url

        def my_load_file_from_url(**kwargs):
            return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))

        def facex_load_file_from_url(**kwargs):
            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))

        def facex_load_file_from_url2(**kwargs):
            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))

        gfpgan.utils.load_file_from_url = my_load_file_from_url
        facexlib.detection.load_file_from_url = facex_load_file_from_url
        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
        user_path = dirname
        have_gfpgan = True
        gfpgan_constructor = GFPGANer

        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
            def name(self):
                return "GFPGAN"

            def restore(self, np_image):
                return gfpgan_fix_faces(np_image)

        shared.face_restorers.append(FaceRestorerGFPGAN())
    except Exception:
        print("Error setting up GFPGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
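Note (sketch, not part of the commit), parallel to the CodeFormer note above; the model directory passed to setup_model is hypothetical and normally supplied by the launcher.

    import numpy as np
    from PIL import Image
    import modules.gfpgan_model as gfpgan_model

    gfpgan_model.setup_model("models/GFPGAN")            # patches the download helpers and registers the restorer
    img = np.array(Image.open("face.png").convert("RGB"), dtype=np.uint8)
    fixed = gfpgan_model.gfpgan_fix_faces(img)           # downloads GFPGANv1.4.pth on first use
    Image.fromarray(fixed).save("face_fixed.png")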
modules/hypernetworks/hypernetwork.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import glob
|
3 |
+
import html
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import traceback
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from ldm.util import default
|
12 |
+
from modules import devices, shared, processing, sd_models
|
13 |
+
import torch
|
14 |
+
from torch import einsum
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
import modules.textual_inversion.dataset
|
17 |
+
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
18 |
+
|
19 |
+
|
20 |
+
class HypernetworkModule(torch.nn.Module):
|
21 |
+
multiplier = 1.0
|
22 |
+
|
23 |
+
def __init__(self, dim, state_dict=None):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.linear1 = torch.nn.Linear(dim, dim * 2)
|
27 |
+
self.linear2 = torch.nn.Linear(dim * 2, dim)
|
28 |
+
|
29 |
+
if state_dict is not None:
|
30 |
+
self.load_state_dict(state_dict, strict=True)
|
31 |
+
else:
|
32 |
+
|
33 |
+
self.linear1.weight.data.normal_(mean=0.0, std=0.01)
|
34 |
+
self.linear1.bias.data.zero_()
|
35 |
+
self.linear2.weight.data.normal_(mean=0.0, std=0.01)
|
36 |
+
self.linear2.bias.data.zero_()
|
37 |
+
|
38 |
+
self.to(devices.device)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
return x + (self.linear2(self.linear1(x))) * self.multiplier
|
42 |
+
|
43 |
+
|
44 |
+
def apply_strength(value=None):
|
45 |
+
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
46 |
+
|
47 |
+
|
48 |
+
class Hypernetwork:
|
49 |
+
filename = None
|
50 |
+
name = None
|
51 |
+
|
52 |
+
def __init__(self, name=None, enable_sizes=None):
|
53 |
+
self.filename = None
|
54 |
+
self.name = name
|
55 |
+
self.layers = {}
|
56 |
+
self.step = 0
|
57 |
+
self.sd_checkpoint = None
|
58 |
+
self.sd_checkpoint_name = None
|
59 |
+
|
60 |
+
for size in enable_sizes or []:
|
61 |
+
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
|
62 |
+
|
63 |
+
def weights(self):
|
64 |
+
res = []
|
65 |
+
|
66 |
+
for k, layers in self.layers.items():
|
67 |
+
for layer in layers:
|
68 |
+
layer.train()
|
69 |
+
res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
|
70 |
+
|
71 |
+
return res
|
72 |
+
|
73 |
+
def save(self, filename):
|
74 |
+
state_dict = {}
|
75 |
+
|
76 |
+
for k, v in self.layers.items():
|
77 |
+
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
78 |
+
|
79 |
+
state_dict['step'] = self.step
|
80 |
+
state_dict['name'] = self.name
|
81 |
+
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
82 |
+
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
83 |
+
|
84 |
+
torch.save(state_dict, filename)
|
85 |
+
|
86 |
+
def load(self, filename):
|
87 |
+
self.filename = filename
|
88 |
+
if self.name is None:
|
89 |
+
self.name = os.path.splitext(os.path.basename(filename))[0]
|
90 |
+
|
91 |
+
state_dict = torch.load(filename, map_location='cpu')
|
92 |
+
|
93 |
+
for size, sd in state_dict.items():
|
94 |
+
if type(size) == int:
|
95 |
+
self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
|
96 |
+
|
97 |
+
self.name = state_dict.get('name', self.name)
|
98 |
+
self.step = state_dict.get('step', 0)
|
99 |
+
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
100 |
+
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
101 |
+
|
102 |
+
|
103 |
+
def list_hypernetworks(path):
|
104 |
+
res = {}
|
105 |
+
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
|
106 |
+
name = os.path.splitext(os.path.basename(filename))[0]
|
107 |
+
res[name] = filename
|
108 |
+
return res
|
109 |
+
|
110 |
+
|
111 |
+
def load_hypernetwork(filename):
|
112 |
+
path = shared.hypernetworks.get(filename, None)
|
113 |
+
if path is not None:
|
114 |
+
print(f"Loading hypernetwork {filename}")
|
115 |
+
try:
|
116 |
+
shared.loaded_hypernetwork = Hypernetwork()
|
117 |
+
shared.loaded_hypernetwork.load(path)
|
118 |
+
|
119 |
+
except Exception:
|
120 |
+
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
121 |
+
print(traceback.format_exc(), file=sys.stderr)
|
122 |
+
else:
|
123 |
+
if shared.loaded_hypernetwork is not None:
|
124 |
+
print(f"Unloading hypernetwork")
|
125 |
+
|
126 |
+
shared.loaded_hypernetwork = None
|
127 |
+
|
128 |
+
|
129 |
+
def find_closest_hypernetwork_name(search: str):
|
130 |
+
if not search:
|
131 |
+
return None
|
132 |
+
search = search.lower()
|
133 |
+
applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
134 |
+
if not applicable:
|
135 |
+
return None
|
136 |
+
applicable = sorted(applicable, key=lambda name: len(name))
|
137 |
+
return applicable[0]
|
138 |
+
|
139 |
+
|
140 |
+
def apply_hypernetwork(hypernetwork, context, layer=None):
|
141 |
+
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
142 |
+
|
143 |
+
if hypernetwork_layers is None:
|
144 |
+
return context, context
|
145 |
+
|
146 |
+
if layer is not None:
|
147 |
+
layer.hyper_k = hypernetwork_layers[0]
|
148 |
+
layer.hyper_v = hypernetwork_layers[1]
|
149 |
+
|
150 |
+
context_k = hypernetwork_layers[0](context)
|
151 |
+
context_v = hypernetwork_layers[1](context)
|
152 |
+
return context_k, context_v
|
153 |
+
|
154 |
+
|
155 |
+
def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
156 |
+
h = self.heads
|
157 |
+
|
158 |
+
q = self.to_q(x)
|
159 |
+
context = default(context, x)
|
160 |
+
|
161 |
+
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
|
162 |
+
k = self.to_k(context_k)
|
163 |
+
v = self.to_v(context_v)
|
164 |
+
|
165 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
166 |
+
|
167 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
168 |
+
|
169 |
+
if mask is not None:
|
170 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
171 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
172 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
173 |
+
sim.masked_fill_(~mask, max_neg_value)
|
174 |
+
|
175 |
+
# attention, what we cannot get enough of
|
176 |
+
attn = sim.softmax(dim=-1)
|
177 |
+
|
178 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
179 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
180 |
+
return self.to_out(out)
|
181 |
+
|
182 |
+
|
183 |
+
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
|
184 |
+
assert hypernetwork_name, 'hypernetwork not selected'
|
185 |
+
|
186 |
+
path = shared.hypernetworks.get(hypernetwork_name, None)
|
187 |
+
shared.loaded_hypernetwork = Hypernetwork()
|
188 |
+
shared.loaded_hypernetwork.load(path)
|
189 |
+
|
190 |
+
shared.state.textinfo = "Initializing hypernetwork training..."
|
191 |
+
shared.state.job_count = steps
|
192 |
+
|
193 |
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
194 |
+
|
195 |
+
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
196 |
+
unload = shared.opts.unload_models_when_training
|
197 |
+
|
198 |
+
if save_hypernetwork_every > 0:
|
199 |
+
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
200 |
+
os.makedirs(hypernetwork_dir, exist_ok=True)
|
201 |
+
else:
|
202 |
+
hypernetwork_dir = None
|
203 |
+
|
204 |
+
if create_image_every > 0:
|
205 |
+
images_dir = os.path.join(log_directory, "images")
|
206 |
+
os.makedirs(images_dir, exist_ok=True)
|
207 |
+
else:
|
208 |
+
images_dir = None
|
209 |
+
|
210 |
+
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
211 |
+
with torch.autocast("cuda"):
|
212 |
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
|
213 |
+
|
214 |
+
if unload:
|
215 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
216 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
217 |
+
|
218 |
+
hypernetwork = shared.loaded_hypernetwork
|
219 |
+
weights = hypernetwork.weights()
|
220 |
+
for weight in weights:
|
221 |
+
weight.requires_grad = True
|
222 |
+
|
223 |
+
losses = torch.zeros((32,))
|
224 |
+
|
225 |
+
last_saved_file = "<none>"
|
226 |
+
last_saved_image = "<none>"
|
227 |
+
|
228 |
+
ititial_step = hypernetwork.step or 0
|
229 |
+
if ititial_step > steps:
|
230 |
+
return hypernetwork, filename
|
231 |
+
|
232 |
+
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
233 |
+
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
234 |
+
|
235 |
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
236 |
+
for i, entry in pbar:
|
237 |
+
hypernetwork.step = i + ititial_step
|
238 |
+
|
239 |
+
scheduler.apply(optimizer, hypernetwork.step)
|
240 |
+
if scheduler.finished:
|
241 |
+
break
|
242 |
+
|
243 |
+
if shared.state.interrupted:
|
244 |
+
break
|
245 |
+
|
246 |
+
with torch.autocast("cuda"):
|
247 |
+
cond = entry.cond.to(devices.device)
|
248 |
+
x = entry.latent.to(devices.device)
|
249 |
+
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
|
250 |
+
del x
|
251 |
+
del cond
|
252 |
+
|
253 |
+
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
254 |
+
|
255 |
+
optimizer.zero_grad()
|
256 |
+
loss.backward()
|
257 |
+
optimizer.step()
|
258 |
+
|
259 |
+
pbar.set_description(f"loss: {losses.mean():.7f}")
|
260 |
+
|
261 |
+
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
262 |
+
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
263 |
+
hypernetwork.save(last_saved_file)
|
264 |
+
|
265 |
+
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
266 |
+
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
267 |
+
|
268 |
+
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
269 |
+
|
270 |
+
optimizer.zero_grad()
|
271 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
272 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
273 |
+
|
274 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
275 |
+
sd_model=shared.sd_model,
|
276 |
+
prompt=preview_text,
|
277 |
+
steps=20,
|
278 |
+
do_not_save_grid=True,
|
279 |
+
do_not_save_samples=True,
|
280 |
+
)
|
281 |
+
|
282 |
+
processed = processing.process_images(p)
|
283 |
+
image = processed.images[0] if len(processed.images)>0 else None
|
284 |
+
|
285 |
+
if unload:
|
286 |
+
shared.sd_model.cond_stage_model.to(devices.cpu)
|
287 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
288 |
+
|
289 |
+
if image is not None:
|
290 |
+
shared.state.current_image = image
|
291 |
+
image.save(last_saved_image)
|
292 |
+
last_saved_image += f", prompt: {preview_text}"
|
293 |
+
|
294 |
+
shared.state.job_no = hypernetwork.step
|
295 |
+
|
296 |
+
shared.state.textinfo = f"""
|
297 |
+
<p>
|
298 |
+
Loss: {losses.mean():.7f}<br/>
|
299 |
+
Step: {hypernetwork.step}<br/>
|
300 |
+
Last prompt: {html.escape(entry.cond_text)}<br/>
|
301 |
+
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
302 |
+
Last saved image: {html.escape(last_saved_image)}<br/>
|
303 |
+
</p>
|
304 |
+
"""
|
305 |
+
|
306 |
+
checkpoint = sd_models.select_checkpoint()
|
307 |
+
|
308 |
+
hypernetwork.sd_checkpoint = checkpoint.hash
|
309 |
+
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
310 |
+
hypernetwork.save(filename)
|
311 |
+
|
312 |
+
return hypernetwork, filename
|
313 |
+
|
314 |
+
|
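Note: the training loop above reports its loss as the mean of a fixed-size rolling buffer (losses = torch.zeros((32,)), indexed modulo its length). A minimal standalone sketch of that technique, with a made-up value standing in for loss.item():

import torch

losses = torch.zeros((32,))                      # ring buffer of the most recent losses
for step in range(100):
    loss_value = 1.0 / (step + 1)                # stand-in for loss.item()
    losses[step % losses.shape[0]] = loss_value
    print(f"step {step}: mean of last {losses.shape[0]} losses = {losses.mean():.4f}")

Until the buffer has been filled once (32 steps here), the mean is pulled toward zero by the untouched slots; the same caveat applies to the progress-bar description in the loop above.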
modules/hypernetworks/ui.py
ADDED
@@ -0,0 +1,47 @@
1 |
+
import html
|
2 |
+
import os
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
import modules.textual_inversion.textual_inversion
|
7 |
+
import modules.textual_inversion.preprocess
|
8 |
+
from modules import sd_hijack, shared, devices
|
9 |
+
from modules.hypernetworks import hypernetwork
|
10 |
+
|
11 |
+
|
12 |
+
def create_hypernetwork(name, enable_sizes):
|
13 |
+
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
14 |
+
assert not os.path.exists(fn), f"file {fn} already exists"
|
15 |
+
|
16 |
+
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
|
17 |
+
hypernet.save(fn)
|
18 |
+
|
19 |
+
shared.reload_hypernetworks()
|
20 |
+
|
21 |
+
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
|
22 |
+
|
23 |
+
|
24 |
+
def train_hypernetwork(*args):
|
25 |
+
|
26 |
+
initial_hypernetwork = shared.loaded_hypernetwork
|
27 |
+
|
28 |
+
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
29 |
+
|
30 |
+
try:
|
31 |
+
sd_hijack.undo_optimizations()
|
32 |
+
|
33 |
+
hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
|
34 |
+
|
35 |
+
res = f"""
|
36 |
+
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
|
37 |
+
Hypernetwork saved to {html.escape(filename)}
|
38 |
+
"""
|
39 |
+
return res, ""
|
40 |
+
except Exception:
|
41 |
+
raise
|
42 |
+
finally:
|
43 |
+
shared.loaded_hypernetwork = initial_hypernetwork
|
44 |
+
shared.sd_model.cond_stage_model.to(devices.device)
|
45 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
46 |
+
sd_hijack.apply_optimizations()
|
47 |
+
|
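Note: train_hypernetwork above relies on a try/finally block so that the previously loaded hypernetwork and the model's device placement are restored whether training finishes, raises, or is interrupted. A minimal sketch of that restore-on-exit pattern; load_state, restore_state and risky_operation are hypothetical placeholders, not webui functions:

def guarded_run(risky_operation, load_state, restore_state):
    previous = load_state()
    try:
        return risky_operation()
    finally:
        # runs on normal return, on exception, and on interruption alike
        restore_state(previous)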
modules/images.py
ADDED
@@ -0,0 +1,465 @@
1 |
+
import datetime
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
from collections import namedtuple
|
5 |
+
import re
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import piexif
|
9 |
+
import piexif.helper
|
10 |
+
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
11 |
+
from fonts.ttf import Roboto
|
12 |
+
import string
|
13 |
+
|
14 |
+
from modules import sd_samplers, shared
|
15 |
+
from modules.shared import opts, cmd_opts
|
16 |
+
|
17 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
18 |
+
|
19 |
+
|
20 |
+
def image_grid(imgs, batch_size=1, rows=None):
|
21 |
+
if rows is None:
|
22 |
+
if opts.n_rows > 0:
|
23 |
+
rows = opts.n_rows
|
24 |
+
elif opts.n_rows == 0:
|
25 |
+
rows = batch_size
|
26 |
+
else:
|
27 |
+
rows = math.sqrt(len(imgs))
|
28 |
+
rows = round(rows)
|
29 |
+
|
30 |
+
cols = math.ceil(len(imgs) / rows)
|
31 |
+
|
32 |
+
w, h = imgs[0].size
|
33 |
+
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
|
34 |
+
|
35 |
+
for i, img in enumerate(imgs):
|
36 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
37 |
+
|
38 |
+
return grid
|
39 |
+
|
40 |
+
|
41 |
+
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
42 |
+
|
43 |
+
|
44 |
+
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
45 |
+
w = image.width
|
46 |
+
h = image.height
|
47 |
+
|
48 |
+
non_overlap_width = tile_w - overlap
|
49 |
+
non_overlap_height = tile_h - overlap
|
50 |
+
|
51 |
+
cols = math.ceil((w - overlap) / non_overlap_width)
|
52 |
+
rows = math.ceil((h - overlap) / non_overlap_height)
|
53 |
+
|
54 |
+
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
55 |
+
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
56 |
+
|
57 |
+
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
58 |
+
for row in range(rows):
|
59 |
+
row_images = []
|
60 |
+
|
61 |
+
y = int(row * dy)
|
62 |
+
|
63 |
+
if y + tile_h >= h:
|
64 |
+
y = h - tile_h
|
65 |
+
|
66 |
+
for col in range(cols):
|
67 |
+
x = int(col * dx)
|
68 |
+
|
69 |
+
if x + tile_w >= w:
|
70 |
+
x = w - tile_w
|
71 |
+
|
72 |
+
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
73 |
+
|
74 |
+
row_images.append([x, tile_w, tile])
|
75 |
+
|
76 |
+
grid.tiles.append([y, tile_h, row_images])
|
77 |
+
|
78 |
+
return grid
|
79 |
+
|
80 |
+
|
81 |
+
def combine_grid(grid):
|
82 |
+
def make_mask_image(r):
|
83 |
+
r = r * 255 / grid.overlap
|
84 |
+
r = r.astype(np.uint8)
|
85 |
+
return Image.fromarray(r, 'L')
|
86 |
+
|
87 |
+
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
88 |
+
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
89 |
+
|
90 |
+
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
91 |
+
for y, h, row in grid.tiles:
|
92 |
+
combined_row = Image.new("RGB", (grid.image_w, h))
|
93 |
+
for x, w, tile in row:
|
94 |
+
if x == 0:
|
95 |
+
combined_row.paste(tile, (0, 0))
|
96 |
+
continue
|
97 |
+
|
98 |
+
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
99 |
+
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
100 |
+
|
101 |
+
if y == 0:
|
102 |
+
combined_image.paste(combined_row, (0, 0))
|
103 |
+
continue
|
104 |
+
|
105 |
+
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
|
106 |
+
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
|
107 |
+
|
108 |
+
return combined_image
|
109 |
+
|
110 |
+
|
111 |
+
class GridAnnotation:
|
112 |
+
def __init__(self, text='', is_active=True):
|
113 |
+
self.text = text
|
114 |
+
self.is_active = is_active
|
115 |
+
self.size = None
|
116 |
+
|
117 |
+
|
118 |
+
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
119 |
+
def wrap(drawing, text, font, line_length):
|
120 |
+
lines = ['']
|
121 |
+
for word in text.split():
|
122 |
+
line = f'{lines[-1]} {word}'.strip()
|
123 |
+
if drawing.textlength(line, font=font) <= line_length:
|
124 |
+
lines[-1] = line
|
125 |
+
else:
|
126 |
+
lines.append(word)
|
127 |
+
return lines
|
128 |
+
|
129 |
+
def draw_texts(drawing, draw_x, draw_y, lines):
|
130 |
+
for i, line in enumerate(lines):
|
131 |
+
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
132 |
+
|
133 |
+
if not line.is_active:
|
134 |
+
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
135 |
+
|
136 |
+
draw_y += line.size[1] + line_spacing
|
137 |
+
|
138 |
+
fontsize = (width + height) // 25
|
139 |
+
line_spacing = fontsize // 2
|
140 |
+
|
141 |
+
try:
|
142 |
+
fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
|
143 |
+
except Exception:
|
144 |
+
fnt = ImageFont.truetype(Roboto, fontsize)
|
145 |
+
|
146 |
+
color_active = (0, 0, 0)
|
147 |
+
color_inactive = (153, 153, 153)
|
148 |
+
|
149 |
+
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
150 |
+
|
151 |
+
cols = im.width // width
|
152 |
+
rows = im.height // height
|
153 |
+
|
154 |
+
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
155 |
+
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
156 |
+
|
157 |
+
calc_img = Image.new("RGB", (1, 1), "white")
|
158 |
+
calc_d = ImageDraw.Draw(calc_img)
|
159 |
+
|
160 |
+
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
161 |
+
items = [] + texts
|
162 |
+
texts.clear()
|
163 |
+
|
164 |
+
for line in items:
|
165 |
+
wrapped = wrap(calc_d, line.text, fnt, allowed_width)
|
166 |
+
texts += [GridAnnotation(x, line.is_active) for x in wrapped]
|
167 |
+
|
168 |
+
for line in texts:
|
169 |
+
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
|
170 |
+
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
171 |
+
|
172 |
+
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
173 |
+
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
174 |
+
ver_texts]
|
175 |
+
|
176 |
+
pad_top = max(hor_text_heights) + line_spacing * 2
|
177 |
+
|
178 |
+
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
|
179 |
+
result.paste(im, (pad_left, pad_top))
|
180 |
+
|
181 |
+
d = ImageDraw.Draw(result)
|
182 |
+
|
183 |
+
for col in range(cols):
|
184 |
+
x = pad_left + width * col + width / 2
|
185 |
+
y = pad_top / 2 - hor_text_heights[col] / 2
|
186 |
+
|
187 |
+
draw_texts(d, x, y, hor_texts[col])
|
188 |
+
|
189 |
+
for row in range(rows):
|
190 |
+
x = pad_left / 2
|
191 |
+
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
|
192 |
+
|
193 |
+
draw_texts(d, x, y, ver_texts[row])
|
194 |
+
|
195 |
+
return result
|
196 |
+
|
197 |
+
|
198 |
+
def draw_prompt_matrix(im, width, height, all_prompts):
|
199 |
+
prompts = all_prompts[1:]
|
200 |
+
boundary = math.ceil(len(prompts) / 2)
|
201 |
+
|
202 |
+
prompts_horiz = prompts[:boundary]
|
203 |
+
prompts_vert = prompts[boundary:]
|
204 |
+
|
205 |
+
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
206 |
+
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
207 |
+
|
208 |
+
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
209 |
+
|
210 |
+
|
211 |
+
def resize_image(resize_mode, im, width, height):
|
212 |
+
def resize(im, w, h):
|
213 |
+
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
|
214 |
+
return im.resize((w, h), resample=LANCZOS)
|
215 |
+
|
216 |
+
scale = max(w / im.width, h / im.height)
|
217 |
+
|
218 |
+
if scale > 1.0:
|
219 |
+
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
|
220 |
+
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
|
221 |
+
|
222 |
+
upscaler = upscalers[0]
|
223 |
+
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
224 |
+
|
225 |
+
if im.width != w or im.height != h:
|
226 |
+
im = im.resize((w, h), resample=LANCZOS)
|
227 |
+
|
228 |
+
return im
|
229 |
+
|
230 |
+
if resize_mode == 0:
|
231 |
+
res = resize(im, width, height)
|
232 |
+
|
233 |
+
elif resize_mode == 1:
|
234 |
+
ratio = width / height
|
235 |
+
src_ratio = im.width / im.height
|
236 |
+
|
237 |
+
src_w = width if ratio > src_ratio else im.width * height // im.height
|
238 |
+
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
239 |
+
|
240 |
+
resized = resize(im, src_w, src_h)
|
241 |
+
res = Image.new("RGB", (width, height))
|
242 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
243 |
+
|
244 |
+
else:
|
245 |
+
ratio = width / height
|
246 |
+
src_ratio = im.width / im.height
|
247 |
+
|
248 |
+
src_w = width if ratio < src_ratio else im.width * height // im.height
|
249 |
+
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
250 |
+
|
251 |
+
resized = resize(im, src_w, src_h)
|
252 |
+
res = Image.new("RGB", (width, height))
|
253 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
254 |
+
|
255 |
+
if ratio < src_ratio:
|
256 |
+
fill_height = height // 2 - src_h // 2
|
257 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
258 |
+
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
259 |
+
elif ratio > src_ratio:
|
260 |
+
fill_width = width // 2 - src_w // 2
|
261 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
262 |
+
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
263 |
+
|
264 |
+
return res
|
265 |
+
|
266 |
+
|
267 |
+
invalid_filename_chars = '<>:"/\\|?*\n'
|
268 |
+
invalid_filename_prefix = ' '
|
269 |
+
invalid_filename_postfix = ' .'
|
270 |
+
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
271 |
+
max_filename_part_length = 128
|
272 |
+
|
273 |
+
|
274 |
+
def sanitize_filename_part(text, replace_spaces=True):
|
275 |
+
if replace_spaces:
|
276 |
+
text = text.replace(' ', '_')
|
277 |
+
|
278 |
+
text = text.translate({ord(x): '_' for x in invalid_filename_chars})
|
279 |
+
text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
|
280 |
+
text = text.rstrip(invalid_filename_postfix)
|
281 |
+
return text
|
282 |
+
|
283 |
+
|
284 |
+
def apply_filename_pattern(x, p, seed, prompt):
|
285 |
+
max_prompt_words = opts.directories_max_prompt_words
|
286 |
+
|
287 |
+
if seed is not None:
|
288 |
+
x = x.replace("[seed]", str(seed))
|
289 |
+
|
290 |
+
if p is not None:
|
291 |
+
x = x.replace("[steps]", str(p.steps))
|
292 |
+
x = x.replace("[cfg]", str(p.cfg_scale))
|
293 |
+
x = x.replace("[width]", str(p.width))
|
294 |
+
x = x.replace("[height]", str(p.height))
|
295 |
+
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
|
296 |
+
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
297 |
+
|
298 |
+
x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
|
299 |
+
x = x.replace("[date]", datetime.date.today().isoformat())
|
300 |
+
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
301 |
+
x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
|
302 |
+
|
303 |
+
# Apply [prompt] last, because it may contain any of the other replacement patterns.
|
304 |
+
if prompt is not None:
|
305 |
+
x = x.replace("[prompt]", sanitize_filename_part(prompt))
|
306 |
+
if "[prompt_no_styles]" in x:
|
307 |
+
prompt_no_style = prompt
|
308 |
+
for style in shared.prompt_styles.get_style_prompts(p.styles):
|
309 |
+
if len(style) > 0:
|
310 |
+
style_parts = [y for y in style.split("{prompt}")]
|
311 |
+
for part in style_parts:
|
312 |
+
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
313 |
+
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
314 |
+
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
|
315 |
+
|
316 |
+
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
|
317 |
+
if "[prompt_words]" in x:
|
318 |
+
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
319 |
+
if len(words) == 0:
|
320 |
+
words = ["empty"]
|
321 |
+
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
322 |
+
|
323 |
+
if cmd_opts.hide_ui_dir_config:
|
324 |
+
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
|
325 |
+
|
326 |
+
return x
|
327 |
+
|
328 |
+
|
329 |
+
def get_next_sequence_number(path, basename):
|
330 |
+
"""
|
331 |
+
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
332 |
+
|
333 |
+
The sequence starts at 0.
|
334 |
+
"""
|
335 |
+
result = -1
|
336 |
+
if basename != '':
|
337 |
+
basename = basename + "-"
|
338 |
+
|
339 |
+
prefix_length = len(basename)
|
340 |
+
for p in os.listdir(path):
|
341 |
+
if p.startswith(basename):
|
342 |
+
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
343 |
+
try:
|
344 |
+
result = max(int(l[0]), result)
|
345 |
+
except ValueError:
|
346 |
+
pass
|
347 |
+
|
348 |
+
return result + 1
|
349 |
+
|
350 |
+
|
351 |
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
352 |
+
'''Save an image.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
image (`PIL.Image`):
|
356 |
+
The image to be saved.
|
357 |
+
path (`str`):
|
358 |
+
The directory to save the image in. Note that the option `save_to_dirs` will cause the image to be saved into a subdirectory.
|
359 |
+
basename (`str`):
|
360 |
+
The base filename which will be applied to `filename pattern`.
|
361 |
+
seed, prompt, short_filename,
|
362 |
+
extension (`str`):
|
363 |
+
Image file extension, default is `png`.
|
364 |
+
pnginfo_section_name (`str`):
|
365 |
+
Specify the name of the section which `info` will be saved in.
|
366 |
+
info (`str` or `PngImagePlugin.iTXt`):
|
367 |
+
PNG info chunks.
|
368 |
+
existing_info (`dict`):
|
369 |
+
Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
|
370 |
+
no_prompt:
|
371 |
+
If true, the image will not be saved into a prompt-derived subdirectory even when `opts.save_to_dirs` is enabled.
|
372 |
+
p (`StableDiffusionProcessing`)
|
373 |
+
forced_filename (`str`):
|
374 |
+
If specified, `basename` and filename pattern will be ignored.
|
375 |
+
save_to_dirs (bool):
|
376 |
+
If true, the image will be saved into a subdirectory of `path`.
|
377 |
+
|
378 |
+
Returns: (fullfn, txt_fullfn)
|
379 |
+
fullfn (`str`):
|
380 |
+
The full path of the saved image.
|
381 |
+
txt_fullfn (`str` or None):
|
382 |
+
If a text file is saved for this image, this will be its full path. Otherwise None.
|
383 |
+
'''
|
384 |
+
if short_filename or prompt is None or seed is None:
|
385 |
+
file_decoration = ""
|
386 |
+
elif opts.save_to_dirs:
|
387 |
+
file_decoration = opts.samples_filename_pattern or "[seed]"
|
388 |
+
else:
|
389 |
+
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
390 |
+
|
391 |
+
if file_decoration != "":
|
392 |
+
file_decoration = "-" + file_decoration.lower()
|
393 |
+
|
394 |
+
file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
|
395 |
+
|
396 |
+
if extension == 'png' and opts.enable_pnginfo and info is not None:
|
397 |
+
pnginfo = PngImagePlugin.PngInfo()
|
398 |
+
|
399 |
+
if existing_info is not None:
|
400 |
+
for k, v in existing_info.items():
|
401 |
+
pnginfo.add_text(k, str(v))
|
402 |
+
|
403 |
+
pnginfo.add_text(pnginfo_section_name, info)
|
404 |
+
else:
|
405 |
+
pnginfo = None
|
406 |
+
|
407 |
+
if save_to_dirs is None:
|
408 |
+
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
409 |
+
|
410 |
+
if save_to_dirs:
|
411 |
+
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
|
412 |
+
path = os.path.join(path, dirname)
|
413 |
+
|
414 |
+
os.makedirs(path, exist_ok=True)
|
415 |
+
|
416 |
+
if forced_filename is None:
|
417 |
+
basecount = get_next_sequence_number(path, basename)
|
418 |
+
fullfn = "a.png"
|
419 |
+
fullfn_without_extension = "a"
|
420 |
+
for i in range(500):
|
421 |
+
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
422 |
+
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
423 |
+
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
424 |
+
if not os.path.exists(fullfn):
|
425 |
+
break
|
426 |
+
else:
|
427 |
+
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
428 |
+
fullfn_without_extension = os.path.join(path, forced_filename)
|
429 |
+
|
430 |
+
def exif_bytes():
|
431 |
+
return piexif.dump({
|
432 |
+
"Exif": {
|
433 |
+
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
|
434 |
+
},
|
435 |
+
})
|
436 |
+
|
437 |
+
if extension.lower() in ("jpg", "jpeg", "webp"):
|
438 |
+
image.save(fullfn, quality=opts.jpeg_quality)
|
439 |
+
if opts.enable_pnginfo and info is not None:
|
440 |
+
piexif.insert(exif_bytes(), fullfn)
|
441 |
+
else:
|
442 |
+
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
443 |
+
|
444 |
+
target_side_length = 4000
|
445 |
+
oversize = image.width > target_side_length or image.height > target_side_length
|
446 |
+
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
|
447 |
+
ratio = image.width / image.height
|
448 |
+
|
449 |
+
if oversize and ratio > 1:
|
450 |
+
image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
|
451 |
+
elif oversize:
|
452 |
+
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
|
453 |
+
|
454 |
+
image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
|
455 |
+
if opts.enable_pnginfo and info is not None:
|
456 |
+
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
|
457 |
+
|
458 |
+
if opts.save_txt and info is not None:
|
459 |
+
txt_fullfn = f"{fullfn_without_extension}.txt"
|
460 |
+
with open(txt_fullfn, "w", encoding="utf8") as file:
|
461 |
+
file.write(info + "\n")
|
462 |
+
else:
|
463 |
+
txt_fullfn = None
|
464 |
+
|
465 |
+
return fullfn, txt_fullfn
|
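Note: a short usage sketch for the tiling helpers above. It assumes the webui package is on sys.path so that modules.images (which pulls in modules.shared) imports cleanly; the image here is a synthetic placeholder:

from PIL import Image
from modules import images

im = Image.new("RGB", (1280, 896), "gray")
grid = images.split_grid(im, tile_w=512, tile_h=512, overlap=64)

for y, h, row in grid.tiles:
    for x, w, tile in row:
        pass  # each overlapping 512x512 tile could be processed (e.g. upscaled) here

restored = images.combine_grid(grid)
assert restored.size == im.size  # tiles are blended back with gradient masks over the overlap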
modules/img2img.py
ADDED
@@ -0,0 +1,137 @@
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image, ImageOps, ImageChops
|
8 |
+
|
9 |
+
from modules import devices
|
10 |
+
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
11 |
+
from modules.shared import opts, state
|
12 |
+
import modules.shared as shared
|
13 |
+
import modules.processing as processing
|
14 |
+
from modules.ui import plaintext_to_html
|
15 |
+
import modules.images as images
|
16 |
+
import modules.scripts
|
17 |
+
|
18 |
+
|
19 |
+
def process_batch(p, input_dir, output_dir, args):
|
20 |
+
processing.fix_seed(p)
|
21 |
+
|
22 |
+
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
23 |
+
|
24 |
+
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
25 |
+
|
26 |
+
save_normally = output_dir == ''
|
27 |
+
|
28 |
+
p.do_not_save_grid = True
|
29 |
+
p.do_not_save_samples = not save_normally
|
30 |
+
|
31 |
+
state.job_count = len(images) * p.n_iter
|
32 |
+
|
33 |
+
for i, image in enumerate(images):
|
34 |
+
state.job = f"{i+1} out of {len(images)}"
|
35 |
+
if state.skipped:
|
36 |
+
state.skipped = False
|
37 |
+
|
38 |
+
if state.interrupted:
|
39 |
+
break
|
40 |
+
|
41 |
+
img = Image.open(image)
|
42 |
+
p.init_images = [img] * p.batch_size
|
43 |
+
|
44 |
+
proc = modules.scripts.scripts_img2img.run(p, *args)
|
45 |
+
if proc is None:
|
46 |
+
proc = process_images(p)
|
47 |
+
|
48 |
+
for n, processed_image in enumerate(proc.images):
|
49 |
+
filename = os.path.basename(image)
|
50 |
+
|
51 |
+
if n > 0:
|
52 |
+
left, right = os.path.splitext(filename)
|
53 |
+
filename = f"{left}-{n}{right}"
|
54 |
+
|
55 |
+
if not save_normally:
|
56 |
+
processed_image.save(os.path.join(output_dir, filename))
|
57 |
+
|
58 |
+
|
59 |
+
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
60 |
+
is_inpaint = mode == 1
|
61 |
+
is_batch = mode == 2
|
62 |
+
|
63 |
+
if is_inpaint:
|
64 |
+
if mask_mode == 0:
|
65 |
+
image = init_img_with_mask['image']
|
66 |
+
mask = init_img_with_mask['mask']
|
67 |
+
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
68 |
+
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
69 |
+
image = image.convert('RGB')
|
70 |
+
else:
|
71 |
+
image = init_img_inpaint
|
72 |
+
mask = init_mask_inpaint
|
73 |
+
else:
|
74 |
+
image = init_img
|
75 |
+
mask = None
|
76 |
+
|
77 |
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
78 |
+
|
79 |
+
p = StableDiffusionProcessingImg2Img(
|
80 |
+
sd_model=shared.sd_model,
|
81 |
+
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
|
82 |
+
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
|
83 |
+
prompt=prompt,
|
84 |
+
negative_prompt=negative_prompt,
|
85 |
+
styles=[prompt_style, prompt_style2],
|
86 |
+
seed=seed,
|
87 |
+
subseed=subseed,
|
88 |
+
subseed_strength=subseed_strength,
|
89 |
+
seed_resize_from_h=seed_resize_from_h,
|
90 |
+
seed_resize_from_w=seed_resize_from_w,
|
91 |
+
seed_enable_extras=seed_enable_extras,
|
92 |
+
sampler_index=sampler_index,
|
93 |
+
batch_size=batch_size,
|
94 |
+
n_iter=n_iter,
|
95 |
+
steps=steps,
|
96 |
+
cfg_scale=cfg_scale,
|
97 |
+
width=width,
|
98 |
+
height=height,
|
99 |
+
restore_faces=restore_faces,
|
100 |
+
tiling=tiling,
|
101 |
+
init_images=[image],
|
102 |
+
mask=mask,
|
103 |
+
mask_blur=mask_blur,
|
104 |
+
inpainting_fill=inpainting_fill,
|
105 |
+
resize_mode=resize_mode,
|
106 |
+
denoising_strength=denoising_strength,
|
107 |
+
inpaint_full_res=inpaint_full_res,
|
108 |
+
inpaint_full_res_padding=inpaint_full_res_padding,
|
109 |
+
inpainting_mask_invert=inpainting_mask_invert,
|
110 |
+
)
|
111 |
+
|
112 |
+
if shared.cmd_opts.enable_console_prompts:
|
113 |
+
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
114 |
+
|
115 |
+
p.extra_generation_params["Mask blur"] = mask_blur
|
116 |
+
|
117 |
+
if is_batch:
|
118 |
+
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
119 |
+
|
120 |
+
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
|
121 |
+
|
122 |
+
processed = Processed(p, [], p.seed, "")
|
123 |
+
else:
|
124 |
+
processed = modules.scripts.scripts_img2img.run(p, *args)
|
125 |
+
if processed is None:
|
126 |
+
processed = process_images(p)
|
127 |
+
|
128 |
+
shared.total_tqdm.clear()
|
129 |
+
|
130 |
+
generation_info_js = processed.js()
|
131 |
+
if opts.samples_log_stdout:
|
132 |
+
print(generation_info_js)
|
133 |
+
|
134 |
+
if opts.do_not_show_images:
|
135 |
+
processed.images = []
|
136 |
+
|
137 |
+
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
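Note: the inpaint branch above merges two mask sources, the transparent pixels of the sketch canvas and the explicitly painted mask, by taking the lighter of the two. A standalone PIL sketch of that merge (both images are converted to 'L' before combining, which is slightly more defensive than the code above):

from PIL import Image, ImageChops, ImageOps

canvas = Image.new("RGBA", (64, 64), (255, 0, 0, 255))   # stand-in for init_img_with_mask['image']
drawn_mask = Image.new("L", (64, 64), 0)                  # stand-in for init_img_with_mask['mask']

# fully transparent canvas pixels count as masked, as do painted pixels
alpha_mask = ImageOps.invert(canvas.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
merged = ImageChops.lighter(alpha_mask.convert('L'), drawn_mask).convert('L')
image = canvas.convert('RGB')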
modules/interrogate.py
ADDED
@@ -0,0 +1,171 @@
1 |
+
import contextlib
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
from collections import namedtuple
|
6 |
+
import re
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from torchvision import transforms
|
11 |
+
from torchvision.transforms.functional import InterpolationMode
|
12 |
+
|
13 |
+
import modules.shared as shared
|
14 |
+
from modules import devices, paths, lowvram
|
15 |
+
|
16 |
+
blip_image_eval_size = 384
|
17 |
+
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
18 |
+
clip_model_name = 'ViT-L/14'
|
19 |
+
|
20 |
+
Category = namedtuple("Category", ["name", "topn", "items"])
|
21 |
+
|
22 |
+
re_topn = re.compile(r"\.top(\d+)\.")
|
23 |
+
|
24 |
+
|
25 |
+
class InterrogateModels:
|
26 |
+
blip_model = None
|
27 |
+
clip_model = None
|
28 |
+
clip_preprocess = None
|
29 |
+
categories = None
|
30 |
+
dtype = None
|
31 |
+
|
32 |
+
def __init__(self, content_dir):
|
33 |
+
self.categories = []
|
34 |
+
|
35 |
+
if os.path.exists(content_dir):
|
36 |
+
for filename in os.listdir(content_dir):
|
37 |
+
m = re_topn.search(filename)
|
38 |
+
topn = 1 if m is None else int(m.group(1))
|
39 |
+
|
40 |
+
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
|
41 |
+
lines = [x.strip() for x in file.readlines()]
|
42 |
+
|
43 |
+
self.categories.append(Category(name=filename, topn=topn, items=lines))
|
44 |
+
|
45 |
+
def load_blip_model(self):
|
46 |
+
import models.blip
|
47 |
+
|
48 |
+
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
49 |
+
blip_model.eval()
|
50 |
+
|
51 |
+
return blip_model
|
52 |
+
|
53 |
+
def load_clip_model(self):
|
54 |
+
import clip
|
55 |
+
|
56 |
+
model, preprocess = clip.load(clip_model_name)
|
57 |
+
model.eval()
|
58 |
+
model = model.to(shared.device)
|
59 |
+
|
60 |
+
return model, preprocess
|
61 |
+
|
62 |
+
def load(self):
|
63 |
+
if self.blip_model is None:
|
64 |
+
self.blip_model = self.load_blip_model()
|
65 |
+
if not shared.cmd_opts.no_half:
|
66 |
+
self.blip_model = self.blip_model.half()
|
67 |
+
|
68 |
+
self.blip_model = self.blip_model.to(shared.device)
|
69 |
+
|
70 |
+
if self.clip_model is None:
|
71 |
+
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
72 |
+
if not shared.cmd_opts.no_half:
|
73 |
+
self.clip_model = self.clip_model.half()
|
74 |
+
|
75 |
+
self.clip_model = self.clip_model.to(shared.device)
|
76 |
+
|
77 |
+
self.dtype = next(self.clip_model.parameters()).dtype
|
78 |
+
|
79 |
+
def send_clip_to_ram(self):
|
80 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
81 |
+
if self.clip_model is not None:
|
82 |
+
self.clip_model = self.clip_model.to(devices.cpu)
|
83 |
+
|
84 |
+
def send_blip_to_ram(self):
|
85 |
+
if not shared.opts.interrogate_keep_models_in_memory:
|
86 |
+
if self.blip_model is not None:
|
87 |
+
self.blip_model = self.blip_model.to(devices.cpu)
|
88 |
+
|
89 |
+
def unload(self):
|
90 |
+
self.send_clip_to_ram()
|
91 |
+
self.send_blip_to_ram()
|
92 |
+
|
93 |
+
devices.torch_gc()
|
94 |
+
|
95 |
+
def rank(self, image_features, text_array, top_count=1):
|
96 |
+
import clip
|
97 |
+
|
98 |
+
if shared.opts.interrogate_clip_dict_limit != 0:
|
99 |
+
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
100 |
+
|
101 |
+
top_count = min(top_count, len(text_array))
|
102 |
+
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
|
103 |
+
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
104 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
105 |
+
|
106 |
+
similarity = torch.zeros((1, len(text_array))).to(shared.device)
|
107 |
+
for i in range(image_features.shape[0]):
|
108 |
+
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
109 |
+
similarity /= image_features.shape[0]
|
110 |
+
|
111 |
+
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
112 |
+
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
113 |
+
|
114 |
+
def generate_caption(self, pil_image):
|
115 |
+
gpu_image = transforms.Compose([
|
116 |
+
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
117 |
+
transforms.ToTensor(),
|
118 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
119 |
+
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
123 |
+
|
124 |
+
return caption[0]
|
125 |
+
|
126 |
+
def interrogate(self, pil_image, include_ranks=False):
|
127 |
+
res = None
|
128 |
+
|
129 |
+
try:
|
130 |
+
|
131 |
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
132 |
+
lowvram.send_everything_to_cpu()
|
133 |
+
devices.torch_gc()
|
134 |
+
|
135 |
+
self.load()
|
136 |
+
|
137 |
+
caption = self.generate_caption(pil_image)
|
138 |
+
self.send_blip_to_ram()
|
139 |
+
devices.torch_gc()
|
140 |
+
|
141 |
+
res = caption
|
142 |
+
|
143 |
+
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
144 |
+
|
145 |
+
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
146 |
+
with torch.no_grad(), precision_scope("cuda"):
|
147 |
+
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
148 |
+
|
149 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
150 |
+
|
151 |
+
if shared.opts.interrogate_use_builtin_artists:
|
152 |
+
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
153 |
+
|
154 |
+
res += ", " + artist[0]
|
155 |
+
|
156 |
+
for name, topn, items in self.categories:
|
157 |
+
matches = self.rank(image_features, items, top_count=topn)
|
158 |
+
for match, score in matches:
|
159 |
+
if include_ranks:
|
160 |
+
res += f", ({match}:{score})"
|
161 |
+
else:
|
162 |
+
res += ", " + match
|
163 |
+
|
164 |
+
except Exception:
|
165 |
+
print("Error interrogating", file=sys.stderr)
|
166 |
+
print(traceback.format_exc(), file=sys.stderr)
|
167 |
+
res = (res or "") + "<error>"
|
168 |
+
|
169 |
+
self.unload()
|
170 |
+
|
171 |
+
return res
|
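Note: InterrogateModels.rank above scores candidate labels by softmaxed cosine similarity between normalized image and text embeddings. A self-contained sketch of that ranking step, with random tensors standing in for CLIP features (no model download involved):

import torch

labels = ["painting", "photograph", "sketch"]
image_features = torch.randn(1, 512)
text_features = torch.randn(len(labels), 512)

# normalize so the dot product becomes cosine similarity
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_idx = similarity.topk(2, dim=-1)
for prob, idx in zip(top_probs[0], top_idx[0]):
    print(f"{labels[idx.item()]}: {prob.item() * 100:.1f}%")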
modules/ldsr_model.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
from basicsr.utils.download_util import load_file_from_url
|
6 |
+
|
7 |
+
from modules.upscaler import Upscaler, UpscalerData
|
8 |
+
from modules.ldsr_model_arch import LDSR
|
9 |
+
from modules import shared
|
10 |
+
|
11 |
+
|
12 |
+
class UpscalerLDSR(Upscaler):
|
13 |
+
def __init__(self, user_path):
|
14 |
+
self.name = "LDSR"
|
15 |
+
self.user_path = user_path
|
16 |
+
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
17 |
+
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
18 |
+
super().__init__()
|
19 |
+
scaler_data = UpscalerData("LDSR", None, self)
|
20 |
+
self.scalers = [scaler_data]
|
21 |
+
|
22 |
+
def load_model(self, path: str):
|
23 |
+
# Remove incorrect project.yaml file if too big
|
24 |
+
yaml_path = os.path.join(self.model_path, "project.yaml")
|
25 |
+
old_model_path = os.path.join(self.model_path, "model.pth")
|
26 |
+
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
27 |
+
if os.path.exists(yaml_path):
|
28 |
+
statinfo = os.stat(yaml_path)
|
29 |
+
if statinfo.st_size >= 10485760:
|
30 |
+
print("Removing invalid LDSR YAML file.")
|
31 |
+
os.remove(yaml_path)
|
32 |
+
if os.path.exists(old_model_path):
|
33 |
+
print("Renaming model from model.pth to model.ckpt")
|
34 |
+
os.rename(old_model_path, new_model_path)
|
35 |
+
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
36 |
+
file_name="model.ckpt", progress=True)
|
37 |
+
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
38 |
+
file_name="project.yaml", progress=True)
|
39 |
+
|
40 |
+
try:
|
41 |
+
return LDSR(model, yaml)
|
42 |
+
|
43 |
+
except Exception:
|
44 |
+
print("Error importing LDSR:", file=sys.stderr)
|
45 |
+
print(traceback.format_exc(), file=sys.stderr)
|
46 |
+
return None
|
47 |
+
|
48 |
+
def do_upscale(self, img, path):
|
49 |
+
ldsr = self.load_model(path)
|
50 |
+
if ldsr is None:
|
51 |
+
print("NO LDSR!")
|
52 |
+
return img
|
53 |
+
ddim_steps = shared.opts.ldsr_steps
|
54 |
+
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
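Note: load_model above fetches both the checkpoint and the YAML config with basicsr's load_file_from_url, which (as I understand it) returns the cached local path and skips the download when the target file already exists. A small hedged sketch of that helper in isolation; the model_dir here is illustrative, not a webui constant:

from basicsr.utils.download_util import load_file_from_url

ckpt_path = load_file_from_url(
    url="https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1",
    model_dir="models/LDSR",        # illustrative target directory
    file_name="model.ckpt",
    progress=True,
)
print(ckpt_path)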
modules/ldsr_model_arch.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
import gc
|
2 |
+
import time
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
from PIL import Image
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
|
12 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
13 |
+
from ldm.util import instantiate_from_config, ismap
|
14 |
+
|
15 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
16 |
+
|
17 |
+
|
18 |
+
# Create LDSR Class
|
19 |
+
class LDSR:
|
20 |
+
def load_model_from_config(self, half_attention):
|
21 |
+
print(f"Loading model from {self.modelPath}")
|
22 |
+
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
23 |
+
sd = pl_sd["state_dict"]
|
24 |
+
config = OmegaConf.load(self.yamlPath)
|
25 |
+
model = instantiate_from_config(config.model)
|
26 |
+
model.load_state_dict(sd, strict=False)
|
27 |
+
model.cuda()
|
28 |
+
if half_attention:
|
29 |
+
model = model.half()
|
30 |
+
|
31 |
+
model.eval()
|
32 |
+
return {"model": model}
|
33 |
+
|
34 |
+
def __init__(self, model_path, yaml_path):
|
35 |
+
self.modelPath = model_path
|
36 |
+
self.yamlPath = yaml_path
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def run(model, selected_path, custom_steps, eta):
|
40 |
+
example = get_cond(selected_path)
|
41 |
+
|
42 |
+
n_runs = 1
|
43 |
+
guider = None
|
44 |
+
ckwargs = None
|
45 |
+
ddim_use_x0_pred = False
|
46 |
+
temperature = 1.
|
47 |
+
eta = eta
|
48 |
+
custom_shape = None
|
49 |
+
|
50 |
+
height, width = example["image"].shape[1:3]
|
51 |
+
split_input = height >= 128 and width >= 128
|
52 |
+
|
53 |
+
if split_input:
|
54 |
+
ks = 128
|
55 |
+
stride = 64
|
56 |
+
vqf = 4 #
|
57 |
+
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
58 |
+
"vqf": vqf,
|
59 |
+
"patch_distributed_vq": True,
|
60 |
+
"tie_braker": False,
|
61 |
+
"clip_max_weight": 0.5,
|
62 |
+
"clip_min_weight": 0.01,
|
63 |
+
"clip_max_tie_weight": 0.5,
|
64 |
+
"clip_min_tie_weight": 0.01}
|
65 |
+
else:
|
66 |
+
if hasattr(model, "split_input_params"):
|
67 |
+
delattr(model, "split_input_params")
|
68 |
+
|
69 |
+
x_t = None
|
70 |
+
logs = None
|
71 |
+
for n in range(n_runs):
|
72 |
+
if custom_shape is not None:
|
73 |
+
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
74 |
+
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
75 |
+
|
76 |
+
logs = make_convolutional_sample(example, model,
|
77 |
+
custom_steps=custom_steps,
|
78 |
+
eta=eta, quantize_x0=False,
|
79 |
+
custom_shape=custom_shape,
|
80 |
+
temperature=temperature, noise_dropout=0.,
|
81 |
+
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
82 |
+
ddim_use_x0_pred=ddim_use_x0_pred
|
83 |
+
)
|
84 |
+
return logs
|
85 |
+
|
86 |
+
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
87 |
+
model = self.load_model_from_config(half_attention)
|
88 |
+
|
89 |
+
# Run settings
|
90 |
+
diffusion_steps = int(steps)
|
91 |
+
eta = 1.0
|
92 |
+
|
93 |
+
down_sample_method = 'Lanczos'
|
94 |
+
|
95 |
+
gc.collect()
|
96 |
+
torch.cuda.empty_cache()
|
97 |
+
|
98 |
+
im_og = image
|
99 |
+
width_og, height_og = im_og.size
|
100 |
+
# If we can adjust the max upscale size, then the 4 below should be our variable
|
101 |
+
down_sample_rate = target_scale / 4
|
102 |
+
wd = width_og * down_sample_rate
|
103 |
+
hd = height_og * down_sample_rate
|
104 |
+
width_downsampled_pre = int(wd)
|
105 |
+
height_downsampled_pre = int(hd)
|
106 |
+
|
107 |
+
if down_sample_rate != 1:
|
108 |
+
print(
|
109 |
+
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
110 |
+
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
111 |
+
else:
|
112 |
+
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
113 |
+
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
114 |
+
|
115 |
+
sample = logs["sample"]
|
116 |
+
sample = sample.detach().cpu()
|
117 |
+
sample = torch.clamp(sample, -1., 1.)
|
118 |
+
sample = (sample + 1.) / 2. * 255
|
119 |
+
sample = sample.numpy().astype(np.uint8)
|
120 |
+
sample = np.transpose(sample, (0, 2, 3, 1))
|
121 |
+
a = Image.fromarray(sample[0])
|
122 |
+
|
123 |
+
del model
|
124 |
+
gc.collect()
|
125 |
+
torch.cuda.empty_cache()
|
126 |
+
return a
|
127 |
+
|
128 |
+
|
129 |
+
def get_cond(selected_path):
|
130 |
+
example = dict()
|
131 |
+
up_f = 4
|
132 |
+
c = selected_path.convert('RGB')
|
133 |
+
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
134 |
+
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
135 |
+
antialias=True)
|
136 |
+
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
137 |
+
c = rearrange(c, '1 c h w -> 1 h w c')
|
138 |
+
c = 2. * c - 1.
|
139 |
+
|
140 |
+
c = c.to(torch.device("cuda"))
|
141 |
+
example["LR_image"] = c
|
142 |
+
example["image"] = c_up
|
143 |
+
|
144 |
+
return example
|
145 |
+
|
146 |
+
|
147 |
+
@torch.no_grad()
|
148 |
+
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
149 |
+
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
150 |
+
corrector_kwargs=None, x_t=None
|
151 |
+
):
|
152 |
+
ddim = DDIMSampler(model)
|
153 |
+
bs = shape[0]
|
154 |
+
shape = shape[1:]
|
155 |
+
print(f"Sampling with eta = {eta}; steps: {steps}")
|
156 |
+
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
157 |
+
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
158 |
+
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
159 |
+
score_corrector=score_corrector,
|
160 |
+
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
161 |
+
|
162 |
+
return samples, intermediates
|
163 |
+
|
164 |
+
|
165 |
+
@torch.no_grad()
|
166 |
+
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
167 |
+
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
168 |
+
log = dict()
|
169 |
+
|
170 |
+
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
171 |
+
return_first_stage_outputs=True,
|
172 |
+
force_c_encode=not (hasattr(model, 'split_input_params')
|
173 |
+
and model.cond_stage_key == 'coordinates_bbox'),
|
174 |
+
return_original_cond=True)
|
175 |
+
|
176 |
+
if custom_shape is not None:
|
177 |
+
z = torch.randn(custom_shape)
|
178 |
+
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
179 |
+
|
180 |
+
z0 = None
|
181 |
+
|
182 |
+
log["input"] = x
|
183 |
+
log["reconstruction"] = xrec
|
184 |
+
|
185 |
+
if ismap(xc):
|
186 |
+
log["original_conditioning"] = model.to_rgb(xc)
|
187 |
+
if hasattr(model, 'cond_stage_key'):
|
188 |
+
log[model.cond_stage_key] = model.to_rgb(xc)
|
189 |
+
|
190 |
+
else:
|
191 |
+
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
192 |
+
if model.cond_stage_model:
|
193 |
+
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
194 |
+
if model.cond_stage_key == 'class_label':
|
195 |
+
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
196 |
+
|
197 |
+
with model.ema_scope("Plotting"):
|
198 |
+
t0 = time.time()
|
199 |
+
|
200 |
+
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
201 |
+
eta=eta,
|
202 |
+
quantize_x0=quantize_x0, mask=None, x0=z0,
|
203 |
+
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
204 |
+
x_t=x_T)
|
205 |
+
t1 = time.time()
|
206 |
+
|
207 |
+
if ddim_use_x0_pred:
|
208 |
+
sample = intermediates['pred_x0'][-1]
|
209 |
+
|
210 |
+
x_sample = model.decode_first_stage(sample)
|
211 |
+
|
212 |
+
try:
|
213 |
+
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
214 |
+
log["sample_noquant"] = x_sample_noquant
|
215 |
+
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
216 |
+
except Exception:
|
217 |
+
pass
|
218 |
+
|
219 |
+
log["sample"] = x_sample
|
220 |
+
log["time"] = t1 - t0
|
221 |
+
|
222 |
+
return log
|
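Note: the tail of super_resolution above converts the decoded sample from a [-1, 1] float tensor in NCHW layout into a uint8 PIL image. A standalone sketch of just that conversion, with a random tensor standing in for model.decode_first_stage(...):

import numpy as np
import torch
from PIL import Image

sample = torch.rand(1, 3, 64, 64) * 2 - 1        # stand-in for the decoded latent, values in [-1, 1]
sample = torch.clamp(sample.detach().cpu(), -1., 1.)
sample = (sample + 1.) / 2. * 255                # map [-1, 1] -> [0, 255]
sample = sample.numpy().astype(np.uint8)
sample = np.transpose(sample, (0, 2, 3, 1))      # NCHW -> NHWC
img = Image.fromarray(sample[0])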
modules/lowvram.py
ADDED
@@ -0,0 +1,82 @@
1 |
+
import torch
|
2 |
+
from modules.devices import get_optimal_device
|
3 |
+
|
4 |
+
module_in_gpu = None
|
5 |
+
cpu = torch.device("cpu")
|
6 |
+
device = gpu = get_optimal_device()
|
7 |
+
|
8 |
+
|
9 |
+
def send_everything_to_cpu():
|
10 |
+
global module_in_gpu
|
11 |
+
|
12 |
+
if module_in_gpu is not None:
|
13 |
+
module_in_gpu.to(cpu)
|
14 |
+
|
15 |
+
module_in_gpu = None
|
16 |
+
|
17 |
+
|
18 |
+
def setup_for_low_vram(sd_model, use_medvram):
|
19 |
+
parents = {}
|
20 |
+
|
21 |
+
def send_me_to_gpu(module, _):
|
22 |
+
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
23 |
+
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
24 |
+
be in CPU
|
25 |
+
"""
|
26 |
+
global module_in_gpu
|
27 |
+
|
28 |
+
module = parents.get(module, module)
|
29 |
+
|
30 |
+
if module_in_gpu == module:
|
31 |
+
return
|
32 |
+
|
33 |
+
if module_in_gpu is not None:
|
34 |
+
module_in_gpu.to(cpu)
|
35 |
+
|
36 |
+
module.to(gpu)
|
37 |
+
module_in_gpu = module
|
38 |
+
|
39 |
+
# see below for register_forward_pre_hook;
|
40 |
+
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
41 |
+
# useless here, and we just replace those methods
|
42 |
+
def first_stage_model_encode_wrap(self, encoder, x):
|
43 |
+
send_me_to_gpu(self, None)
|
44 |
+
return encoder(x)
|
45 |
+
|
46 |
+
def first_stage_model_decode_wrap(self, decoder, z):
|
47 |
+
send_me_to_gpu(self, None)
|
48 |
+
return decoder(z)
|
49 |
+
|
50 |
+
# remove three big modules, cond, first_stage, and unet from the model and then
|
51 |
+
# send the model to GPU. Then put the modules back; they will remain on the CPU.
|
52 |
+
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
53 |
+
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
54 |
+
sd_model.to(device)
|
55 |
+
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
56 |
+
|
57 |
+
# register hooks for the first two models
|
58 |
+
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
59 |
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
60 |
+
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
61 |
+
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
62 |
+
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
63 |
+
|
64 |
+
if use_medvram:
|
65 |
+
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
66 |
+
else:
|
67 |
+
diff_model = sd_model.model.diffusion_model
|
68 |
+
|
69 |
+
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
|
70 |
+
# so that only one of them is in GPU at a time
|
71 |
+
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
72 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
73 |
+
sd_model.model.to(device)
|
74 |
+
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
75 |
+
|
76 |
+
# install hooks for bits of third model
|
77 |
+
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
78 |
+
for block in diff_model.input_blocks:
|
79 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
80 |
+
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
81 |
+
for block in diff_model.output_blocks:
|
82 |
+
block.register_forward_pre_hook(send_me_to_gpu)
|
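Note: a minimal sketch of the forward-pre-hook swapping used by setup_for_low_vram above, with two tiny Linear layers standing in for the big Stable Diffusion submodules. Only one of them lives on the target device at any moment; the other waits on the CPU:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
active_module = None

def send_me_to_gpu(module, _inputs):
    global active_module
    if active_module is module:
        return
    if active_module is not None:
        active_module.to(cpu)        # evict whichever module was active before
    module.to(device)
    active_module = module

a = nn.Linear(4, 4)
b = nn.Linear(4, 4)
a.register_forward_pre_hook(send_me_to_gpu)
b.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 4, device=device)
a(x)   # hook moves `a` onto `device` first
b(x)   # hook evicts `a` to the CPU, then moves `b` onto `device`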
modules/masking.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
from PIL import Image, ImageFilter, ImageOps
|
2 |
+
|
3 |
+
|
4 |
+
def get_crop_region(mask, pad=0):
|
5 |
+
"""finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
6 |
+
For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
|
7 |
+
|
8 |
+
h, w = mask.shape
|
9 |
+
|
10 |
+
crop_left = 0
|
11 |
+
for i in range(w):
|
12 |
+
if not (mask[:, i] == 0).all():
|
13 |
+
break
|
14 |
+
crop_left += 1
|
15 |
+
|
16 |
+
crop_right = 0
|
17 |
+
for i in reversed(range(w)):
|
18 |
+
if not (mask[:, i] == 0).all():
|
19 |
+
break
|
20 |
+
crop_right += 1
|
21 |
+
|
22 |
+
crop_top = 0
|
23 |
+
for i in range(h):
|
24 |
+
if not (mask[i] == 0).all():
|
25 |
+
break
|
26 |
+
crop_top += 1
|
27 |
+
|
28 |
+
crop_bottom = 0
|
29 |
+
for i in reversed(range(h)):
|
30 |
+
if not (mask[i] == 0).all():
|
31 |
+
break
|
32 |
+
crop_bottom += 1
|
33 |
+
|
34 |
+
return (
|
35 |
+
int(max(crop_left-pad, 0)),
|
36 |
+
int(max(crop_top-pad, 0)),
|
37 |
+
int(min(w - crop_right + pad, w)),
|
38 |
+
int(min(h - crop_bottom + pad, h))
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
|
43 |
+
"""expands the crop region returned by get_crop_region() to match the aspect ratio of the image the region will be processed in; returns the expanded region
|
44 |
+
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
|
45 |
+
|
46 |
+
x1, y1, x2, y2 = crop_region
|
47 |
+
|
48 |
+
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
49 |
+
ratio_processing = processing_width / processing_height
|
50 |
+
|
51 |
+
if ratio_crop_region > ratio_processing:
|
52 |
+
desired_height = (x2 - x1) * ratio_processing
|
53 |
+
desired_height_diff = int(desired_height - (y2-y1))
|
54 |
+
y1 -= desired_height_diff//2
|
55 |
+
y2 += desired_height_diff - desired_height_diff//2
|
56 |
+
if y2 >= image_height:
|
57 |
+
diff = y2 - image_height
|
58 |
+
y2 -= diff
|
59 |
+
y1 -= diff
|
60 |
+
if y1 < 0:
|
61 |
+
y2 -= y1
|
62 |
+
y1 -= y1
|
63 |
+
if y2 >= image_height:
|
64 |
+
y2 = image_height
|
65 |
+
else:
|
66 |
+
desired_width = (y2 - y1) * ratio_processing
|
67 |
+
desired_width_diff = int(desired_width - (x2-x1))
|
68 |
+
x1 -= desired_width_diff//2
|
69 |
+
x2 += desired_width_diff - desired_width_diff//2
|
70 |
+
if x2 >= image_width:
|
71 |
+
diff = x2 - image_width
|
72 |
+
x2 -= diff
|
73 |
+
x1 -= diff
|
74 |
+
if x1 < 0:
|
75 |
+
x2 -= x1
|
76 |
+
x1 -= x1
|
77 |
+
if x2 >= image_width:
|
78 |
+
x2 = image_width
|
79 |
+
|
80 |
+
return x1, y1, x2, y2
|
81 |
+
|
82 |
+
|
83 |
+
def fill(image, mask):
|
84 |
+
"""fills masked regions with colors from image using blur. Not extremely effective."""
|
85 |
+
|
86 |
+
image_mod = Image.new('RGBA', (image.width, image.height))
|
87 |
+
|
88 |
+
image_masked = Image.new('RGBa', (image.width, image.height))
|
89 |
+
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
|
90 |
+
|
91 |
+
image_masked = image_masked.convert('RGBa')
|
92 |
+
|
93 |
+
for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
|
94 |
+
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
|
95 |
+
for _ in range(repeats):
|
96 |
+
image_mod.alpha_composite(blurred)
|
97 |
+
|
98 |
+
return image_mod.convert("RGB")
|
99 |
+
|
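Note (not part of the commit): a small usage sketch of the two crop helpers above, using a hypothetical toy mask. get_crop_region finds the bounding box of the painted area (with padding), and expand_crop_region widens it toward the processing aspect ratio.

import numpy as np

# toy 64x64 mask with a painted rectangle near the top-left corner
mask = np.zeros((64, 64), dtype=np.uint8)
mask[8:24, 4:20] = 255

region = get_crop_region(mask, pad=4)                   # e.g. (0, 4, 24, 28)
region = expand_crop_region(region, 512, 512, 64, 64)   # stretched toward the 1:1 processing ratio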
modules/memmon.py
ADDED
@@ -0,0 +1,85 @@
import threading
import time
from collections import defaultdict

import torch


class MemUsageMonitor(threading.Thread):
    run_flag = None
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        try:
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def run(self):
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            torch.cuda.reset_peak_memory_stats()
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            while self.run_flag.is_set():
                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                self.data["min_free"] = min(self.data["min_free"], free)

                time.sleep(1 / self.opts.memmon_poll_rate)

    def dump_debug(self):
        print(self, 'recorded data:')
        for k, v in self.read().items():
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        self.run_flag.set()

    def read(self):
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        self.run_flag.clear()
        return self.read()
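Note (not part of the commit): a hedged usage sketch of MemUsageMonitor. The Opts stand-in below is hypothetical; in the webui the real options object comes from shared.opts, and monitoring requires a CUDA device.

class Opts:
    memmon_poll_rate = 8  # polls per second (stand-in for shared.opts)

mem_mon = MemUsageMonitor("MemMon", "cuda", Opts())
mem_mon.start()        # thread idles until monitoring is requested
mem_mon.monitor()      # begin recording before a generation job
# ... run the job ...
stats = mem_mon.stop() # dict with total, active_peak, reserved_peak, system_peak, min_free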
modules/modelloader.py
ADDED
@@ -0,0 +1,153 @@
import glob
import os
import shutil
import importlib
from urllib.parse import urlparse

from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path


def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
    """
    A one-and-done loader to try finding the desired models in specified directories.

    @param download_name: Specify to download from model_url immediately.
    @param model_url: If no other models are found, this will be downloaded on upscale.
    @param model_path: The location to store/find models in.
    @param command_path: A command-line argument to search for models in first.
    @param ext_filter: An optional list of filename extensions to filter by
    @return: A list of paths containing the desired model(s)
    """
    output = []

    if ext_filter is None:
        ext_filter = []

    try:
        places = []

        if command_path is not None and command_path != model_path:
            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
            if os.path.exists(pretrained_path):
                print(f"Appending path: {pretrained_path}")
                places.append(pretrained_path)
            elif os.path.exists(command_path):
                places.append(command_path)

        places.append(model_path)

        for place in places:
            if os.path.exists(place):
                for file in glob.iglob(place + '**/**', recursive=True):
                    full_path = file
                    if os.path.isdir(full_path):
                        continue
                    if len(ext_filter) != 0:
                        model_name, extension = os.path.splitext(file)
                        if extension not in ext_filter:
                            continue
                    if file not in output:
                        output.append(full_path)

        if model_url is not None and len(output) == 0:
            if download_name is not None:
                dl = load_file_from_url(model_url, model_path, True, download_name)
                output.append(dl)
            else:
                output.append(model_url)

    except Exception:
        pass

    return output


def friendly_name(file: str):
    if "http" in file:
        file = urlparse(file).path

    file = os.path.basename(file)
    model_name, extension = os.path.splitext(file)
    return model_name


def cleanup_models():
    # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
    # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
    # somehow auto-register and just do these things...
    root_path = script_path
    src_path = models_path
    dest_path = os.path.join(models_path, "Stable-diffusion")
    move_files(src_path, dest_path, ".ckpt")
    src_path = os.path.join(root_path, "ESRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    move_files(src_path, dest_path)
    src_path = os.path.join(root_path, "gfpgan")
    dest_path = os.path.join(models_path, "GFPGAN")
    move_files(src_path, dest_path)
    src_path = os.path.join(root_path, "SwinIR")
    dest_path = os.path.join(models_path, "SwinIR")
    move_files(src_path, dest_path)
    src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
    dest_path = os.path.join(models_path, "LDSR")
    move_files(src_path, dest_path)


def move_files(src_path: str, dest_path: str, ext_filter: str = None):
    try:
        if not os.path.exists(dest_path):
            os.makedirs(dest_path)
        if os.path.exists(src_path):
            for file in os.listdir(src_path):
                fullpath = os.path.join(src_path, file)
                if os.path.isfile(fullpath):
                    if ext_filter is not None:
                        if ext_filter not in file:
                            continue
                    print(f"Moving {file} from {src_path} to {dest_path}.")
                    try:
                        shutil.move(fullpath, dest_path)
                    except:
                        pass
            if len(os.listdir(src_path)) == 0:
                print(f"Removing empty folder: {src_path}")
                shutil.rmtree(src_path, True)
    except:
        pass


def load_upscalers():
    sd = shared.script_path
    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
    # so we'll try to import any _model.py files before looking in __subclasses__
    modules_dir = os.path.join(sd, "modules")
    for file in os.listdir(modules_dir):
        if "_model.py" in file:
            model_name = file.replace("_model.py", "")
            full_model = f"modules.{model_name}_model"
            try:
                importlib.import_module(full_model)
            except:
                pass
    datas = []
    c_o = vars(shared.cmd_opts)
    for cls in Upscaler.__subclasses__():
        name = cls.__name__
        module_name = cls.__module__
        module = importlib.import_module(module_name)
        class_ = getattr(module, name)
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        opt_string = None
        try:
            if cmd_name in c_o:
                opt_string = c_o[cmd_name]
        except:
            pass
        scaler = class_(opt_string)
        for child in scaler.scalers:
            datas.append(child)

    shared.sd_upscalers = datas
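Note (not part of the commit): a hedged sketch of how load_models above is typically called by an upscaler module. The directory, URL, and file name here are hypothetical placeholders, not values from this repository.

found = load_models(
    model_path="models/ESRGAN",                      # where models are stored/searched
    model_url="https://example.com/ESRGAN_x4.pth",   # fallback download if nothing is found
    command_path=None,
    ext_filter=[".pth"],
    download_name="ESRGAN_x4.pth",
)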
modules/ngrok.py
ADDED
@@ -0,0 +1,15 @@
from pyngrok import ngrok, conf, exception


def connect(token, port):
    if token == None:
        token = 'None'
    conf.get_default().auth_token = token
    try:
        public_url = ngrok.connect(port).public_url
    except exception.PyngrokNgrokError:
        print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
              f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
    else:
        print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
              'You can use this link after the launch is complete.')
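Note (not part of the commit): a minimal usage sketch of connect above. The token value is a hypothetical placeholder; a real ngrok authtoken and the pyngrok package are required.

connect(token="YOUR_NGROK_AUTHTOKEN", port=7860)  # tunnels the local Gradio server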
modules/paths.py
ADDED
@@ -0,0 +1,40 @@
import argparse
import os
import sys
import modules.safe

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
    if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
        sd_path = os.path.abspath(possible_sd_path)
        break

assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)

path_dirs = [
    (sd_path, 'ldm', 'Stable Diffusion', []),
    (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
    (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]

paths = {}

for d, must_exist, what, options in path_dirs:
    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
    if not os.path.exists(must_exist_path):
        print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
    else:
        d = os.path.abspath(d)
        if "atstart" in options:
            sys.path.insert(0, d)
        else:
            sys.path.append(d)
        paths[what] = d
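Note (not part of the commit): a hedged sketch of how other modules can consume the paths registry built above; the lookup key must match one of the names in path_dirs.

from modules.paths import paths

blip_dir = paths.get("BLIP")  # None if the repository was not found at startup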
modules/processing.py
ADDED
@@ -0,0 +1,721 @@
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
import logging


# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8


def setup_color_correction(image):
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, image):
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    return image


def get_correct_sampler(p):
    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
        return sd_samplers.samplers
    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
        return sd_samplers.samplers_for_img2img

class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = 0
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = opts.ddim_discretize
        self.s_churn = opts.s_churn
        self.s_tmin = opts.s_tmin
        self.s_tmax = float('inf')  # not representable as a standard ui option
        self.s_noise = opts.s_noise

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        raise NotImplementedError()


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = sd_samplers.samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1

        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        obj = {
            "prompt": self.prompt,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        return low * val + high * (1 - val)

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if opts.eta_noise_seed_delta > 0:
                torch.manual_seed(seed + opts.eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


def decode_first_stage(model, x):
    with devices.autocast(disable=x.dtype == devices.dtype_vae):
        x = model.decode_first_stage(x)

    return x


def get_fixed_seed(seed):
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))

    return seed


def fix_seed(p):
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    generation_params = {
        "Steps": p.steps,
        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    if p.outpath_samples is not None:
        os.makedirs(p.outpath_samples, exist_ok=True)

    if p.outpath_grids is not None:
        os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    comments = {}

    shared.prompt_styles.apply_styles(p)

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(seed) == list:
        all_seeds = seed
    else:
        all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]

    if type(subseed) == list:
        all_subseeds = subseed
    else:
        all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(all_prompts, all_seeds, all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if (len(prompts) == 0):
                break

            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)

            if state.interrupted or state.skipped:

                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

            samples_ddim = samples_ddim.to(devices.dtype_vae)
            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None
    firstphase_width = 0
    firstphase_height = 0
    firstphase_width_truncated = 0
    firstphase_height_truncated = 0

    def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.scale_latent = scale_latent
        self.denoising_strength = denoising_strength

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            desired_pixel_count = 512 * 512
            actual_pixel_count = self.width * self.height
            scale = math.sqrt(desired_pixel_count / actual_pixel_count)

            self.firstphase_width = math.ceil(scale * self.width / 64) * 64
            self.firstphase_height = math.ceil(scale * self.height / 64) * 64
            self.firstphase_width_truncated = int(scale * self.width)
            self.firstphase_height_truncated = int(scale * self.height)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)

        truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
        truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f

        samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]

        if self.scale_latent:
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)

            if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
                decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
            else:
                lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

                batch_images = []
                for i, x_sample in enumerate(lowres_samples):
                    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                    x_sample = x_sample.astype(np.uint8)
                    image = Image.fromarray(x_sample)
                    image = images.resize_image(0, image, self.width, self.height)
                    image = np.array(image).astype(np.float32) / 255.0
                    image = np.moveaxis(image, 2, 0)
                    batch_images.append(image)

                decoded_samples = torch.from_numpy(np.array(batch_images))
                decoded_samples = decoded_samples.to(shared.device)
                decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

        shared.state.nextjob()

        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)

        return samples


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples
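Note (not part of the commit): a hedged sketch of how process_images above is typically driven for txt2img. It assumes a checkpoint has already been loaded into shared.sd_model and shared options/state are initialized; the output directories are placeholder values.

p = StableDiffusionProcessingTxt2Img(
    sd_model=shared.sd_model,
    outpath_samples="outputs/txt2img-images",   # placeholder path
    outpath_grids="outputs/txt2img-grids",      # placeholder path
    prompt="a red crown",
    steps=20,
    cfg_scale=7.0,
    width=512,
    height=512,
)
processed = process_images(p)
first_image = processed.images[0]   # PIL image; processed.info holds the generation parameters text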
modules/prompt_parser.py
ADDED
@@ -0,0 +1,366 @@
import re
from collections import namedtuple
from typing import List
import lark

# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']

schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")

def get_learned_conditioning_prompt_schedules(prompts, steps):
    """
    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b d'], [10, 'a c d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c")  # not handling this right now
    [[5, 'a c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    """

    def collect_steps(steps, tree):
        l = [steps]
        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                tree.children[-1] = float(tree.children[-1])
                if tree.children[-1] < 1:
                    tree.children[-1] *= steps
                tree.children[-1] = min(steps, int(tree.children[-1]))
                l.append(tree.children[-1])
            def alternate(self, tree):
                l.extend(range(1, steps+1))
        CollectSteps().visit(tree)
        return sorted(set(l))

    def at_step(step, tree):
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                before, after, _, when = args
                yield before or () if step <= when else after
            def alternate(self, args):
                yield next(args[(step - 1)%len(args)])
            def start(self, args):
                def flatten(x):
                    if type(x) == str:
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)
                return ''.join(flatten(args))
            def plain(self, args):
                yield args[0].value
            def __default__(self, data, children, meta):
                for child in children:
                    yield from child
        return AtStep().transform(tree)

    def get_schedule(prompt):
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError as e:
            if 0:
                import traceback
                traceback.print_exc()
            return [[steps, prompt]]
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
    return [promptdict[prompt] for prompt in prompts]


ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])


def get_learned_conditioning(model, prompts, steps):
    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
    and the sampling step at which this condition is to be replaced by the next one.

    Input:
    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)

    Output:
    [
        [
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
        ],
        [
            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
        ]
    ]
    """
    res = []

    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
    cache = {}

    for prompt, prompt_schedule in zip(prompts, prompt_schedules):

        cached = cache.get(prompt, None)
        if cached is not None:
            res.append(cached)
            continue

        texts = [x[1] for x in prompt_schedule]
        conds = model.get_learned_conditioning(texts)

        cond_schedule = []
        for i, (end_at_step, text) in enumerate(prompt_schedule):
            cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))

        cache[prompt] = cond_schedule
        res.append(cond_schedule)

    return res


re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")

def get_multicond_prompt_list(prompts):
    res_indexes = []

    prompt_flat_list = []
    prompt_indexes = {}

    for prompt in prompts:
        subprompts = re_AND.split(prompt)

        indexes = []
        for subprompt in subprompts:
            match = re_weight.search(subprompt)

            text, weight = match.groups() if match is not None else (subprompt, 1.0)

            weight = float(weight) if weight is not None else 1.0

            index = prompt_indexes.get(text, None)
            if index is None:
                index = len(prompt_flat_list)
                prompt_flat_list.append(text)
                prompt_indexes[text] = index

            indexes.append((index, weight))

        res_indexes.append(indexes)

    return res_indexes, prompt_flat_list, prompt_indexes


class ComposableScheduledPromptConditioning:
    def __init__(self, schedules, weight=1.0):
        self.schedules: List[ScheduledPromptConditioning] = schedules
        self.weight: float = weight


class MulticondLearnedConditioning:
    def __init__(self, shape, batch):
        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch

def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
    """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
    For each prompt, the list is obtained by splitting the prompt using the AND separator.

    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
|
199 |
+
"""
|
200 |
+
|
201 |
+
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
202 |
+
|
203 |
+
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
204 |
+
|
205 |
+
res = []
|
206 |
+
for indexes in res_indexes:
|
207 |
+
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
208 |
+
|
209 |
+
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
210 |
+
|
211 |
+
|
212 |
+
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
213 |
+
param = c[0][0].cond
|
214 |
+
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
215 |
+
for i, cond_schedule in enumerate(c):
|
216 |
+
target_index = 0
|
217 |
+
for current, (end_at, cond) in enumerate(cond_schedule):
|
218 |
+
if current_step <= end_at:
|
219 |
+
target_index = current
|
220 |
+
break
|
221 |
+
res[i] = cond_schedule[target_index].cond
|
222 |
+
|
223 |
+
return res
|
224 |
+
|
225 |
+
|
226 |
+
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
227 |
+
param = c.batch[0][0].schedules[0].cond
|
228 |
+
|
229 |
+
tensors = []
|
230 |
+
conds_list = []
|
231 |
+
|
232 |
+
for batch_no, composable_prompts in enumerate(c.batch):
|
233 |
+
conds_for_batch = []
|
234 |
+
|
235 |
+
for cond_index, composable_prompt in enumerate(composable_prompts):
|
236 |
+
target_index = 0
|
237 |
+
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
238 |
+
if current_step <= end_at:
|
239 |
+
target_index = current
|
240 |
+
break
|
241 |
+
|
242 |
+
conds_for_batch.append((len(tensors), composable_prompt.weight))
|
243 |
+
tensors.append(composable_prompt.schedules[target_index].cond)
|
244 |
+
|
245 |
+
conds_list.append(conds_for_batch)
|
246 |
+
|
247 |
+
# if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
|
248 |
+
# and won't be able to torch.stack them. So this fixes that.
|
249 |
+
token_count = max([x.shape[0] for x in tensors])
|
250 |
+
for i in range(len(tensors)):
|
251 |
+
if tensors[i].shape[0] != token_count:
|
252 |
+
last_vector = tensors[i][-1:]
|
253 |
+
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
254 |
+
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
255 |
+
|
256 |
+
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
257 |
+
|
258 |
+
|
259 |
+
re_attention = re.compile(r"""
|
260 |
+
\\\(|
|
261 |
+
\\\)|
|
262 |
+
\\\[|
|
263 |
+
\\]|
|
264 |
+
\\\\|
|
265 |
+
\\|
|
266 |
+
\(|
|
267 |
+
\[|
|
268 |
+
:([+-]?[.\d]+)\)|
|
269 |
+
\)|
|
270 |
+
]|
|
271 |
+
[^\\()\[\]:]+|
|
272 |
+
:
|
273 |
+
""", re.X)
|
274 |
+
|
275 |
+
|
276 |
+
def parse_prompt_attention(text):
|
277 |
+
"""
|
278 |
+
Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
|
279 |
+
Accepted tokens are:
|
280 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
281 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
282 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
283 |
+
\( - literal character '('
|
284 |
+
\[ - literal character '['
|
285 |
+
\) - literal character ')'
|
286 |
+
\] - literal character ']'
|
287 |
+
\\ - literal character '\'
|
288 |
+
anything else - just text
|
289 |
+
|
290 |
+
>>> parse_prompt_attention('normal text')
|
291 |
+
[['normal text', 1.0]]
|
292 |
+
>>> parse_prompt_attention('an (important) word')
|
293 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
294 |
+
>>> parse_prompt_attention('(unbalanced')
|
295 |
+
[['unbalanced', 1.1]]
|
296 |
+
>>> parse_prompt_attention('\(literal\]')
|
297 |
+
[['(literal]', 1.0]]
|
298 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
299 |
+
[['unnecessaryparens', 1.1]]
|
300 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
301 |
+
[['a ', 1.0],
|
302 |
+
['house', 1.5730000000000004],
|
303 |
+
[' ', 1.1],
|
304 |
+
['on', 1.0],
|
305 |
+
[' a ', 1.1],
|
306 |
+
['hill', 0.55],
|
307 |
+
[', sun, ', 1.1],
|
308 |
+
['sky', 1.4641000000000006],
|
309 |
+
['.', 1.1]]
|
310 |
+
"""
|
311 |
+
|
312 |
+
res = []
|
313 |
+
round_brackets = []
|
314 |
+
square_brackets = []
|
315 |
+
|
316 |
+
round_bracket_multiplier = 1.1
|
317 |
+
square_bracket_multiplier = 1 / 1.1
|
318 |
+
|
319 |
+
def multiply_range(start_position, multiplier):
|
320 |
+
for p in range(start_position, len(res)):
|
321 |
+
res[p][1] *= multiplier
|
322 |
+
|
323 |
+
for m in re_attention.finditer(text):
|
324 |
+
text = m.group(0)
|
325 |
+
weight = m.group(1)
|
326 |
+
|
327 |
+
if text.startswith('\\'):
|
328 |
+
res.append([text[1:], 1.0])
|
329 |
+
elif text == '(':
|
330 |
+
round_brackets.append(len(res))
|
331 |
+
elif text == '[':
|
332 |
+
square_brackets.append(len(res))
|
333 |
+
elif weight is not None and len(round_brackets) > 0:
|
334 |
+
multiply_range(round_brackets.pop(), float(weight))
|
335 |
+
elif text == ')' and len(round_brackets) > 0:
|
336 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
337 |
+
elif text == ']' and len(square_brackets) > 0:
|
338 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
339 |
+
else:
|
340 |
+
res.append([text, 1.0])
|
341 |
+
|
342 |
+
for pos in round_brackets:
|
343 |
+
multiply_range(pos, round_bracket_multiplier)
|
344 |
+
|
345 |
+
for pos in square_brackets:
|
346 |
+
multiply_range(pos, square_bracket_multiplier)
|
347 |
+
|
348 |
+
if len(res) == 0:
|
349 |
+
res = [["", 1.0]]
|
350 |
+
|
351 |
+
# merge runs of identical weights
|
352 |
+
i = 0
|
353 |
+
while i + 1 < len(res):
|
354 |
+
if res[i][1] == res[i + 1][1]:
|
355 |
+
res[i][0] += res[i + 1][0]
|
356 |
+
res.pop(i + 1)
|
357 |
+
else:
|
358 |
+
i += 1
|
359 |
+
|
360 |
+
return res
|
361 |
+
|
362 |
+
if __name__ == "__main__":
|
363 |
+
import doctest
|
364 |
+
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
|
365 |
+
else:
|
366 |
+
import torch # doctest faster
|
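The grammar and the helpers above are easiest to read alongside their doctests. As a quick orientation, here is a minimal sketch of calling the two parsing entry points directly; the prompt text and the 20-step count are made up for illustration, and the expected outputs follow from the scheduling and attention rules shown above.

from modules import prompt_parser

# "[fantasy:sci-fi:0.5]" swaps "fantasy" for "sci-fi" halfway through sampling;
# with 20 steps the fraction 0.5 becomes step 10, so the schedule has two segments.
schedule = prompt_parser.get_learned_conditioning_prompt_schedules(
    ["a [fantasy:sci-fi:0.5] landscape"], 20)[0]
print(schedule)  # [[10, 'a fantasy landscape'], [20, 'a sci-fi landscape']]

# "(word:1.3)" raises attention on a fragment; "[word]" lowers it by 1/1.1.
print(prompt_parser.parse_prompt_attention("a (red:1.3) [bird]"))
# [['a ', 1.0], ['red', 1.3], [' ', 1.0], ['bird', 0.9090909090909091]]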
modules/realesrgan_model.py
ADDED
@@ -0,0 +1,133 @@
import os
import sys
import traceback

import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts


class UpscalerRealESRGAN(Upscaler):
    def __init__(self, path):
        self.name = "RealESRGAN"
        self.user_path = path
        super().__init__()
        try:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
            from realesrgan.archs.srvgg_arch import SRVGGNetCompact
            self.enable = True
            self.scalers = []
            scalers = self.load_models(path)
            for scaler in scalers:
                if scaler.name in opts.realesrgan_enabled_models:
                    self.scalers.append(scaler)

        except Exception:
            print("Error importing Real-ESRGAN:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            self.enable = False
            self.scalers = []

    def do_upscale(self, img, path):
        if not self.enable:
            return img

        info = self.load_model(path)
        if not os.path.exists(info.data_path):
            print("Unable to load RealESRGAN model: %s" % info.name)
            return img

        upsampler = RealESRGANer(
            scale=info.scale,
            model_path=info.data_path,
            model=info.model(),
            half=not cmd_opts.no_half,
            tile=opts.ESRGAN_tile,
            tile_pad=opts.ESRGAN_tile_overlap,
        )

        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]

        image = Image.fromarray(upsampled)
        return image

    def load_model(self, path):
        try:
            info = None
            for scaler in self.scalers:
                if scaler.data_path == path:
                    info = scaler

            if info is None:
                print(f"Unable to find model info: {path}")
                return None

            model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
            info.data_path = model_file
            return info
        except Exception as e:
            print(f"Error loading Real-ESRGAN model info: {e}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return None

    def load_models(self, _):
        return get_realesrgan_models(self)


def get_realesrgan_models(scaler):
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        models = [
            UpscalerData(
                name="R-ESRGAN General 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN General WDN 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN AnimeVideo",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN 4x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 4x+ Anime6B",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 2x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                scale=2,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            ),
        ]
        return models
    except Exception as e:
        print("Error making Real-ESRGAN models list:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
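A minimal sketch of driving this upscaler directly, outside the web UI. It assumes the shared options referenced above (realesrgan_enabled_models, ESRGAN_tile, ESRGAN_tile_overlap, no_half) are already initialised, and the input/output file names are hypothetical; the UI itself invokes the class through the Upscaler base class instead.

from PIL import Image
from modules.realesrgan_model import UpscalerRealESRGAN

upscaler = UpscalerRealESRGAN("")  # no user model directory; the bundled URLs are used
img = Image.open("input.png").convert("RGB")  # hypothetical input image

# pick one of the entries built by get_realesrgan_models(); do_upscale() downloads the
# weights on first use via load_file_from_url and runs RealESRGANer on the image
scaler = next(s for s in upscaler.scalers if s.name == "R-ESRGAN 4x+")
result = upscaler.do_upscale(img, scaler.data_path)
result.save("output.png")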
modules/safe.py
ADDED
@@ -0,0 +1,110 @@
# this code is adapted from the script contributed by anon from /h/

import io
import pickle
import collections
import sys
import traceback

import torch
import numpy
import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage


def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'
        return TypedStorage()

    def find_class(self, module, name):
        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name == 'scalar':
            return numpy.core.multiarray.scalar
        if module == 'numpy' and name == 'dtype':
            return numpy.dtype
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")


allowed_zip_names = ["archive/data.pkl", "archive/version"]
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")


def check_zip_filenames(filename, names):
    for name in names:
        if name in allowed_zip_names:
            continue
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            with z.open('archive/data.pkl') as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            for i in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename)

    except Exception:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
        print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


unsafe_torch_load = torch.load
torch.load = load
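The last two lines monkey-patch torch.load, so merely importing this module routes every later checkpoint load through the restricted unpickler. A small sketch of that effect; the checkpoint path here is hypothetical.

import torch
from modules import safe  # import side effect: torch.load = safe.load

# a well-formed checkpoint loads as before
state = torch.load("model.ckpt", map_location="cpu")

# a pickle referencing a forbidden global (e.g. os.system) makes RestrictedUnpickler
# raise UnpicklingError inside check_pt(), and load() returns None instead of
# executing the payload; the unwrapped loader remains available if really needed:
state = safe.unsafe_torch_load("model.ckpt", map_location="cpu")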
modules/safety.py
ADDED
@@ -0,0 +1,42 @@
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image

import modules.shared as shared

safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]

    return pil_images

# check and replace nsfw content
def check_safety(x_image):
    global safety_feature_extractor, safety_checker

    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept


def censor_batch(x):
    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

    return x
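For orientation, a short sketch of where censor_batch() fits: it takes the decoded sample batch as an NCHW float tensor with values in [0, 1] and returns a tensor of the same layout, with flagged images replaced by the safety checker. The batch below is random dummy data, and the checker weights are downloaded from the safety_model_id repository on first call.

import torch
from modules import safety

x_samples = torch.rand(2, 3, 512, 512)  # hypothetical decoded batch, values in [0, 1]

checked = safety.censor_batch(x_samples)
print(checked.shape)  # torch.Size([2, 3, 512, 512])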
modules/scripts.py
ADDED
@@ -0,0 +1,201 @@
import os
import sys
import traceback

import modules.ui as ui
import gradio as gr

from modules.processing import StableDiffusionProcessing
from modules import shared

class Script:
    filename = None
    args_from = None
    args_to = None

    # The title of the script. This is what will be displayed in the dropdown menu.
    def title(self):
        raise NotImplementedError()

    # How the script is displayed in the UI. See https://gradio.app/docs/#components
    # for the different UI components you can use and how to create them.
    # Most UI components can return a value, such as a boolean for a checkbox.
    # The returned values are passed to the run method as parameters.
    def ui(self, is_img2img):
        pass

    # Determines when the script should be shown in the dropdown menu via the
    # returned value. As an example:
    # is_img2img is True if the current tab is img2img, and False if it is txt2img.
    # Thus, return is_img2img to only show the script on the img2img tab.
    def show(self, is_img2img):
        return True

    # This is where the additional processing is implemented. The parameters include
    # self, the model object "p" (a StableDiffusionProcessing class, see
    # processing.py), and the parameters returned by the ui method.
    # Custom functions can be defined here, and additional libraries can be imported
    # to be used in processing. The return value should be a Processed object, which is
    # what is returned by the process_images method.
    def run(self, *args):
        raise NotImplementedError()

    # The description method is currently unused.
    # To add a description that appears when hovering over the title, amend the "titles"
    # dict in script.js to include the script title (returned by title) as a key, and
    # your description as the value.
    def describe(self):
        return ""


scripts_data = []


def load_scripts(basedir):
    if not os.path.exists(basedir):
        return

    for filename in sorted(os.listdir(basedir)):
        path = os.path.join(basedir, filename)

        if not os.path.isfile(path):
            continue

        try:
            with open(path, "r", encoding="utf8") as file:
                text = file.read()

            from types import ModuleType
            compiled = compile(text, path, 'exec')
            module = ModuleType(filename)
            exec(compiled, module.__dict__)

            for key, script_class in module.__dict__.items():
                if type(script_class) == type and issubclass(script_class, Script):
                    scripts_data.append((script_class, path))

        except Exception:
            print(f"Error loading script: {filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    try:
        res = func(*args, **kwargs)
        return res
    except Exception:
        print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    return default


class ScriptRunner:
    def __init__(self):
        self.scripts = []

    def setup_ui(self, is_img2img):
        for script_class, path in scripts_data:
            script = script_class()
            script.filename = path

            if not script.show(is_img2img):
                continue

            self.scripts.append(script)

        titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]

        dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
        inputs = [dropdown]

        for script in self.scripts:
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", is_img2img)

            if controls is None:
                continue

            for control in controls:
                control.custom_script_source = os.path.basename(script.filename)
                control.visible = False

            inputs += controls
            script.args_to = len(inputs)

        def select_script(script_index):
            if 0 < script_index <= len(self.scripts):
                script = self.scripts[script_index-1]
                args_from = script.args_from
                args_to = script.args_to
            else:
                args_from = 0
                args_to = 0

            return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=inputs
        )

        return inputs

    def run(self, p: StableDiffusionProcessing, *args):
        script_index = args[0]

        if script_index == 0:
            return None

        script = self.scripts[script_index-1]

        if script is None:
            return None

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def reload_sources(self):
        for si, script in list(enumerate(self.scripts)):
            with open(script.filename, "r", encoding="utf8") as file:
                args_from = script.args_from
                args_to = script.args_to
                filename = script.filename
                text = file.read()

                from types import ModuleType

                compiled = compile(text, filename, 'exec')
                module = ModuleType(script.filename)
                exec(compiled, module.__dict__)

                for key, script_class in module.__dict__.items():
                    if type(script_class) == type and issubclass(script_class, Script):
                        self.scripts[si] = script_class()
                        self.scripts[si].filename = filename
                        self.scripts[si].args_from = args_from
                        self.scripts[si].args_to = args_to

scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()

def reload_script_body_only():
    scripts_txt2img.reload_sources()
    scripts_img2img.reload_sources()


def reload_scripts(basedir):
    global scripts_txt2img, scripts_img2img

    scripts_data.clear()
    load_scripts(basedir)

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
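The Script class above is the base that files dropped into the scripts/ directory subclass; load_scripts() compiles each file and registers every Script subclass it finds. A minimal, hypothetical example of such a script, following the comments in the class (the file name and UI control are made up; Processed and process_images come from modules.processing):

# scripts/example_rounds.py (hypothetical)
import gradio as gr
import modules.scripts as scripts
from modules.processing import Processed, process_images


class Script(scripts.Script):
    def title(self):
        return "Example: repeat generation"

    def show(self, is_img2img):
        return True  # offer the script on both the txt2img and img2img tabs

    def ui(self, is_img2img):
        rounds = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Rounds")
        return [rounds]

    def run(self, p, rounds):
        # the values returned from ui() arrive here as extra positional arguments
        images = []
        for _ in range(int(rounds)):
            processed = process_images(p)
            images += processed.images
        return Processed(p, images, p.seed, "")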