13f5d210b8c95ad6b63872633d84859a0a6d9a4258ae3d0e1976b3e737b46fea
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- modules/textual_inversion/logging.py +64 -0
- modules/textual_inversion/preprocess.py +232 -0
- modules/textual_inversion/test_embedding.png +0 -0
- modules/textual_inversion/textual_inversion.py +683 -0
- modules/textual_inversion/ui.py +45 -0
- modules/timer.py +91 -0
- modules/txt2img.py +73 -0
- modules/ui.py +0 -0
- modules/ui_common.py +244 -0
- modules/ui_components.py +74 -0
- modules/ui_extensions.py +651 -0
- modules/ui_extra_networks.py +496 -0
- modules/ui_extra_networks_checkpoints.py +35 -0
- modules/ui_extra_networks_hypernets.py +35 -0
- modules/ui_extra_networks_textual_inversion.py +35 -0
- modules/ui_extra_networks_user_metadata.py +195 -0
- modules/ui_gradio_extensions.py +69 -0
- modules/ui_loadsave.py +210 -0
- modules/ui_postprocessing.py +57 -0
- modules/ui_settings.py +296 -0
- modules/ui_tempdir.py +85 -0
- modules/upscaler.py +144 -0
- modules/xlmr.py +137 -0
- outputs/txt2img-images/2023-07-30/00000-4104476258.png +0 -0
- outputs/txt2img-images/2023-07-30/00001-1264812310.png +0 -0
- outputs/txt2img-images/2023-07-30/00002-629074369.png +0 -0
- outputs/txt2img-images/2023-07-30/00003-3929529382.png +0 -0
- outputs/txt2img-images/2023-07-30/00004-2891905160.png +0 -0
- outputs/txt2img-images/2023-07-30/00005-1703927525.png +0 -0
- outputs/txt2img-images/2023-07-30/00006-1703927525.png +0 -0
- outputs/txt2img-images/2023-07-30/00007-1703927525.png +0 -0
- outputs/txt2img-images/2023-07-30/00008-1703927525.png +0 -0
- outputs/txt2img-images/2023-07-30/00009-1703927525.jpg +0 -0
- outputs/txt2img-images/2023-07-30/00010-210755578.jpg +0 -0
- outputs/txt2img-images/2023-07-30/00011-3978311133.jpg +0 -0
- outputs/txt2img-images/2023-07-30/00012-3786155085.jpg +0 -0
- outputs/txt2img-images/2023-07-30/00013-445379948.jpg +0 -0
- outputs/txt2img-images/2023-07-30/00014-3277595636.png +0 -0
- package.json +11 -0
- params.txt +3 -0
- pyproject.toml +35 -0
- repositories/BLIP/BLIP.gif +3 -0
- repositories/BLIP/CODEOWNERS +2 -0
- repositories/BLIP/CODE_OF_CONDUCT.md +105 -0
- repositories/BLIP/LICENSE.txt +12 -0
- repositories/BLIP/README.md +114 -0
- repositories/BLIP/SECURITY.md +7 -0
- repositories/BLIP/cog.yaml +17 -0
- repositories/BLIP/configs/bert_config.json +21 -0
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text
|
37 |
extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text
|
37 |
extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text
|
38 |
+
repositories/BLIP/BLIP.gif filter=lfs diff=lfs merge=lfs -text
|
modules/textual_inversion/logging.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
saved_params_shared = {
|
6 |
+
"batch_size",
|
7 |
+
"clip_grad_mode",
|
8 |
+
"clip_grad_value",
|
9 |
+
"create_image_every",
|
10 |
+
"data_root",
|
11 |
+
"gradient_step",
|
12 |
+
"initial_step",
|
13 |
+
"latent_sampling_method",
|
14 |
+
"learn_rate",
|
15 |
+
"log_directory",
|
16 |
+
"model_hash",
|
17 |
+
"model_name",
|
18 |
+
"num_of_dataset_images",
|
19 |
+
"steps",
|
20 |
+
"template_file",
|
21 |
+
"training_height",
|
22 |
+
"training_width",
|
23 |
+
}
|
24 |
+
saved_params_ti = {
|
25 |
+
"embedding_name",
|
26 |
+
"num_vectors_per_token",
|
27 |
+
"save_embedding_every",
|
28 |
+
"save_image_with_stored_embedding",
|
29 |
+
}
|
30 |
+
saved_params_hypernet = {
|
31 |
+
"activation_func",
|
32 |
+
"add_layer_norm",
|
33 |
+
"hypernetwork_name",
|
34 |
+
"layer_structure",
|
35 |
+
"save_hypernetwork_every",
|
36 |
+
"use_dropout",
|
37 |
+
"weight_init",
|
38 |
+
}
|
39 |
+
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
40 |
+
saved_params_previews = {
|
41 |
+
"preview_cfg_scale",
|
42 |
+
"preview_height",
|
43 |
+
"preview_negative_prompt",
|
44 |
+
"preview_prompt",
|
45 |
+
"preview_sampler_index",
|
46 |
+
"preview_seed",
|
47 |
+
"preview_steps",
|
48 |
+
"preview_width",
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
def save_settings_to_file(log_directory, all_params):
|
53 |
+
now = datetime.datetime.now()
|
54 |
+
params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
|
55 |
+
|
56 |
+
keys = saved_params_all
|
57 |
+
if all_params.get('preview_from_txt2img'):
|
58 |
+
keys = keys | saved_params_previews
|
59 |
+
|
60 |
+
params.update({k: v for k, v in all_params.items() if k in keys})
|
61 |
+
|
62 |
+
filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
63 |
+
with open(os.path.join(log_directory, filename), "w") as file:
|
64 |
+
json.dump(params, file, indent=4)
|
modules/textual_inversion/preprocess.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image, ImageOps
|
3 |
+
import math
|
4 |
+
import tqdm
|
5 |
+
|
6 |
+
from modules import paths, shared, images, deepbooru
|
7 |
+
from modules.textual_inversion import autocrop
|
8 |
+
|
9 |
+
|
10 |
+
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
11 |
+
try:
|
12 |
+
if process_caption:
|
13 |
+
shared.interrogator.load()
|
14 |
+
|
15 |
+
if process_caption_deepbooru:
|
16 |
+
deepbooru.model.start()
|
17 |
+
|
18 |
+
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
19 |
+
|
20 |
+
finally:
|
21 |
+
|
22 |
+
if process_caption:
|
23 |
+
shared.interrogator.send_blip_to_ram()
|
24 |
+
|
25 |
+
if process_caption_deepbooru:
|
26 |
+
deepbooru.model.stop()
|
27 |
+
|
28 |
+
|
29 |
+
def listfiles(dirname):
|
30 |
+
return os.listdir(dirname)
|
31 |
+
|
32 |
+
|
33 |
+
class PreprocessParams:
|
34 |
+
src = None
|
35 |
+
dstdir = None
|
36 |
+
subindex = 0
|
37 |
+
flip = False
|
38 |
+
process_caption = False
|
39 |
+
process_caption_deepbooru = False
|
40 |
+
preprocess_txt_action = None
|
41 |
+
|
42 |
+
|
43 |
+
def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
|
44 |
+
caption = ""
|
45 |
+
|
46 |
+
if params.process_caption:
|
47 |
+
caption += shared.interrogator.generate_caption(image)
|
48 |
+
|
49 |
+
if params.process_caption_deepbooru:
|
50 |
+
if caption:
|
51 |
+
caption += ", "
|
52 |
+
caption += deepbooru.model.tag_multi(image)
|
53 |
+
|
54 |
+
filename_part = params.src
|
55 |
+
filename_part = os.path.splitext(filename_part)[0]
|
56 |
+
filename_part = os.path.basename(filename_part)
|
57 |
+
|
58 |
+
basename = f"{index:05}-{params.subindex}-{filename_part}"
|
59 |
+
image.save(os.path.join(params.dstdir, f"{basename}.png"))
|
60 |
+
|
61 |
+
if params.preprocess_txt_action == 'prepend' and existing_caption:
|
62 |
+
caption = f"{existing_caption} {caption}"
|
63 |
+
elif params.preprocess_txt_action == 'append' and existing_caption:
|
64 |
+
caption = f"{caption} {existing_caption}"
|
65 |
+
elif params.preprocess_txt_action == 'copy' and existing_caption:
|
66 |
+
caption = existing_caption
|
67 |
+
|
68 |
+
caption = caption.strip()
|
69 |
+
|
70 |
+
if caption:
|
71 |
+
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
|
72 |
+
file.write(caption)
|
73 |
+
|
74 |
+
params.subindex += 1
|
75 |
+
|
76 |
+
|
77 |
+
def save_pic(image, index, params, existing_caption=None):
|
78 |
+
save_pic_with_caption(image, index, params, existing_caption=existing_caption)
|
79 |
+
|
80 |
+
if params.flip:
|
81 |
+
save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
|
82 |
+
|
83 |
+
|
84 |
+
def split_pic(image, inverse_xy, width, height, overlap_ratio):
|
85 |
+
if inverse_xy:
|
86 |
+
from_w, from_h = image.height, image.width
|
87 |
+
to_w, to_h = height, width
|
88 |
+
else:
|
89 |
+
from_w, from_h = image.width, image.height
|
90 |
+
to_w, to_h = width, height
|
91 |
+
h = from_h * to_w // from_w
|
92 |
+
if inverse_xy:
|
93 |
+
image = image.resize((h, to_w))
|
94 |
+
else:
|
95 |
+
image = image.resize((to_w, h))
|
96 |
+
|
97 |
+
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
98 |
+
y_step = (h - to_h) / (split_count - 1)
|
99 |
+
for i in range(split_count):
|
100 |
+
y = int(y_step * i)
|
101 |
+
if inverse_xy:
|
102 |
+
splitted = image.crop((y, 0, y + to_h, to_w))
|
103 |
+
else:
|
104 |
+
splitted = image.crop((0, y, to_w, y + to_h))
|
105 |
+
yield splitted
|
106 |
+
|
107 |
+
# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
|
108 |
+
def center_crop(image: Image, w: int, h: int):
|
109 |
+
iw, ih = image.size
|
110 |
+
if ih / h < iw / w:
|
111 |
+
sw = w * ih / h
|
112 |
+
box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
|
113 |
+
else:
|
114 |
+
sh = h * iw / w
|
115 |
+
box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
|
116 |
+
return image.resize((w, h), Image.Resampling.LANCZOS, box)
|
117 |
+
|
118 |
+
|
119 |
+
def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
|
120 |
+
iw, ih = image.size
|
121 |
+
err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
|
122 |
+
wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
|
123 |
+
if minarea <= w * h <= maxarea and err(w, h) <= threshold),
|
124 |
+
key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
|
125 |
+
default=None
|
126 |
+
)
|
127 |
+
return wh and center_crop(image, *wh)
|
128 |
+
|
129 |
+
|
130 |
+
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
131 |
+
width = process_width
|
132 |
+
height = process_height
|
133 |
+
src = os.path.abspath(process_src)
|
134 |
+
dst = os.path.abspath(process_dst)
|
135 |
+
split_threshold = max(0.0, min(1.0, split_threshold))
|
136 |
+
overlap_ratio = max(0.0, min(0.9, overlap_ratio))
|
137 |
+
|
138 |
+
assert src != dst, 'same directory specified as source and destination'
|
139 |
+
|
140 |
+
os.makedirs(dst, exist_ok=True)
|
141 |
+
|
142 |
+
files = listfiles(src)
|
143 |
+
|
144 |
+
shared.state.job = "preprocess"
|
145 |
+
shared.state.textinfo = "Preprocessing..."
|
146 |
+
shared.state.job_count = len(files)
|
147 |
+
|
148 |
+
params = PreprocessParams()
|
149 |
+
params.dstdir = dst
|
150 |
+
params.flip = process_flip
|
151 |
+
params.process_caption = process_caption
|
152 |
+
params.process_caption_deepbooru = process_caption_deepbooru
|
153 |
+
params.preprocess_txt_action = preprocess_txt_action
|
154 |
+
|
155 |
+
pbar = tqdm.tqdm(files)
|
156 |
+
for index, imagefile in enumerate(pbar):
|
157 |
+
params.subindex = 0
|
158 |
+
filename = os.path.join(src, imagefile)
|
159 |
+
try:
|
160 |
+
img = Image.open(filename)
|
161 |
+
img = ImageOps.exif_transpose(img)
|
162 |
+
img = img.convert("RGB")
|
163 |
+
except Exception:
|
164 |
+
continue
|
165 |
+
|
166 |
+
description = f"Preprocessing [Image {index}/{len(files)}]"
|
167 |
+
pbar.set_description(description)
|
168 |
+
shared.state.textinfo = description
|
169 |
+
|
170 |
+
params.src = filename
|
171 |
+
|
172 |
+
existing_caption = None
|
173 |
+
existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
|
174 |
+
if os.path.exists(existing_caption_filename):
|
175 |
+
with open(existing_caption_filename, 'r', encoding="utf8") as file:
|
176 |
+
existing_caption = file.read()
|
177 |
+
|
178 |
+
if shared.state.interrupted:
|
179 |
+
break
|
180 |
+
|
181 |
+
if img.height > img.width:
|
182 |
+
ratio = (img.width * height) / (img.height * width)
|
183 |
+
inverse_xy = False
|
184 |
+
else:
|
185 |
+
ratio = (img.height * width) / (img.width * height)
|
186 |
+
inverse_xy = True
|
187 |
+
|
188 |
+
process_default_resize = True
|
189 |
+
|
190 |
+
if process_split and ratio < 1.0 and ratio <= split_threshold:
|
191 |
+
for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
|
192 |
+
save_pic(splitted, index, params, existing_caption=existing_caption)
|
193 |
+
process_default_resize = False
|
194 |
+
|
195 |
+
if process_focal_crop and img.height != img.width:
|
196 |
+
|
197 |
+
dnn_model_path = None
|
198 |
+
try:
|
199 |
+
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
|
200 |
+
except Exception as e:
|
201 |
+
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
202 |
+
|
203 |
+
autocrop_settings = autocrop.Settings(
|
204 |
+
crop_width = width,
|
205 |
+
crop_height = height,
|
206 |
+
face_points_weight = process_focal_crop_face_weight,
|
207 |
+
entropy_points_weight = process_focal_crop_entropy_weight,
|
208 |
+
corner_points_weight = process_focal_crop_edges_weight,
|
209 |
+
annotate_image = process_focal_crop_debug,
|
210 |
+
dnn_model_path = dnn_model_path,
|
211 |
+
)
|
212 |
+
for focal in autocrop.crop_image(img, autocrop_settings):
|
213 |
+
save_pic(focal, index, params, existing_caption=existing_caption)
|
214 |
+
process_default_resize = False
|
215 |
+
|
216 |
+
if process_multicrop:
|
217 |
+
cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
218 |
+
if cropped is not None:
|
219 |
+
save_pic(cropped, index, params, existing_caption=existing_caption)
|
220 |
+
else:
|
221 |
+
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
|
222 |
+
process_default_resize = False
|
223 |
+
|
224 |
+
if process_keep_original_size:
|
225 |
+
save_pic(img, index, params, existing_caption=existing_caption)
|
226 |
+
process_default_resize = False
|
227 |
+
|
228 |
+
if process_default_resize:
|
229 |
+
img = images.resize_image(1, img, width, height)
|
230 |
+
save_pic(img, index, params, existing_caption=existing_caption)
|
231 |
+
|
232 |
+
shared.state.nextjob()
|
modules/textual_inversion/test_embedding.png
ADDED
modules/textual_inversion/textual_inversion.py
ADDED
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import namedtuple
|
3 |
+
from contextlib import closing
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import tqdm
|
7 |
+
import html
|
8 |
+
import datetime
|
9 |
+
import csv
|
10 |
+
import safetensors.torch
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
from PIL import Image, PngImagePlugin
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
|
16 |
+
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
17 |
+
import modules.textual_inversion.dataset
|
18 |
+
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
19 |
+
|
20 |
+
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
|
21 |
+
from modules.textual_inversion.logging import save_settings_to_file
|
22 |
+
|
23 |
+
|
24 |
+
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
|
25 |
+
textual_inversion_templates = {}
|
26 |
+
|
27 |
+
|
28 |
+
def list_textual_inversion_templates():
|
29 |
+
textual_inversion_templates.clear()
|
30 |
+
|
31 |
+
for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
|
32 |
+
for fn in fns:
|
33 |
+
path = os.path.join(root, fn)
|
34 |
+
|
35 |
+
textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
|
36 |
+
|
37 |
+
return textual_inversion_templates
|
38 |
+
|
39 |
+
|
40 |
+
class Embedding:
|
41 |
+
def __init__(self, vec, name, step=None):
|
42 |
+
self.vec = vec
|
43 |
+
self.name = name
|
44 |
+
self.step = step
|
45 |
+
self.shape = None
|
46 |
+
self.vectors = 0
|
47 |
+
self.cached_checksum = None
|
48 |
+
self.sd_checkpoint = None
|
49 |
+
self.sd_checkpoint_name = None
|
50 |
+
self.optimizer_state_dict = None
|
51 |
+
self.filename = None
|
52 |
+
self.hash = None
|
53 |
+
self.shorthash = None
|
54 |
+
|
55 |
+
def save(self, filename):
|
56 |
+
embedding_data = {
|
57 |
+
"string_to_token": {"*": 265},
|
58 |
+
"string_to_param": {"*": self.vec},
|
59 |
+
"name": self.name,
|
60 |
+
"step": self.step,
|
61 |
+
"sd_checkpoint": self.sd_checkpoint,
|
62 |
+
"sd_checkpoint_name": self.sd_checkpoint_name,
|
63 |
+
}
|
64 |
+
|
65 |
+
torch.save(embedding_data, filename)
|
66 |
+
|
67 |
+
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
|
68 |
+
optimizer_saved_dict = {
|
69 |
+
'hash': self.checksum(),
|
70 |
+
'optimizer_state_dict': self.optimizer_state_dict,
|
71 |
+
}
|
72 |
+
torch.save(optimizer_saved_dict, f"{filename}.optim")
|
73 |
+
|
74 |
+
def checksum(self):
|
75 |
+
if self.cached_checksum is not None:
|
76 |
+
return self.cached_checksum
|
77 |
+
|
78 |
+
def const_hash(a):
|
79 |
+
r = 0
|
80 |
+
for v in a:
|
81 |
+
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
82 |
+
return r
|
83 |
+
|
84 |
+
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
85 |
+
return self.cached_checksum
|
86 |
+
|
87 |
+
def set_hash(self, v):
|
88 |
+
self.hash = v
|
89 |
+
self.shorthash = self.hash[0:12]
|
90 |
+
|
91 |
+
|
92 |
+
class DirWithTextualInversionEmbeddings:
|
93 |
+
def __init__(self, path):
|
94 |
+
self.path = path
|
95 |
+
self.mtime = None
|
96 |
+
|
97 |
+
def has_changed(self):
|
98 |
+
if not os.path.isdir(self.path):
|
99 |
+
return False
|
100 |
+
|
101 |
+
mt = os.path.getmtime(self.path)
|
102 |
+
if self.mtime is None or mt > self.mtime:
|
103 |
+
return True
|
104 |
+
|
105 |
+
def update(self):
|
106 |
+
if not os.path.isdir(self.path):
|
107 |
+
return
|
108 |
+
|
109 |
+
self.mtime = os.path.getmtime(self.path)
|
110 |
+
|
111 |
+
|
112 |
+
class EmbeddingDatabase:
|
113 |
+
def __init__(self):
|
114 |
+
self.ids_lookup = {}
|
115 |
+
self.word_embeddings = {}
|
116 |
+
self.skipped_embeddings = {}
|
117 |
+
self.expected_shape = -1
|
118 |
+
self.embedding_dirs = {}
|
119 |
+
self.previously_displayed_embeddings = ()
|
120 |
+
|
121 |
+
def add_embedding_dir(self, path):
|
122 |
+
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
123 |
+
|
124 |
+
def clear_embedding_dirs(self):
|
125 |
+
self.embedding_dirs.clear()
|
126 |
+
|
127 |
+
def register_embedding(self, embedding, model):
|
128 |
+
return self.register_embedding_by_name(embedding, model, embedding.name)
|
129 |
+
|
130 |
+
def register_embedding_by_name(self, embedding, model, name):
|
131 |
+
ids = model.cond_stage_model.tokenize([name])[0]
|
132 |
+
first_id = ids[0]
|
133 |
+
if first_id not in self.ids_lookup:
|
134 |
+
self.ids_lookup[first_id] = []
|
135 |
+
if name in self.word_embeddings:
|
136 |
+
# remove old one from the lookup list
|
137 |
+
lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
|
138 |
+
else:
|
139 |
+
lookup = self.ids_lookup[first_id]
|
140 |
+
if embedding is not None:
|
141 |
+
lookup += [(ids, embedding)]
|
142 |
+
self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
|
143 |
+
if embedding is None:
|
144 |
+
# unregister embedding with specified name
|
145 |
+
if name in self.word_embeddings:
|
146 |
+
del self.word_embeddings[name]
|
147 |
+
if len(self.ids_lookup[first_id])==0:
|
148 |
+
del self.ids_lookup[first_id]
|
149 |
+
return None
|
150 |
+
self.word_embeddings[name] = embedding
|
151 |
+
return embedding
|
152 |
+
|
153 |
+
def get_expected_shape(self):
|
154 |
+
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
155 |
+
return vec.shape[1]
|
156 |
+
|
157 |
+
def load_from_file(self, path, filename):
|
158 |
+
name, ext = os.path.splitext(filename)
|
159 |
+
ext = ext.upper()
|
160 |
+
|
161 |
+
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
162 |
+
_, second_ext = os.path.splitext(name)
|
163 |
+
if second_ext.upper() == '.PREVIEW':
|
164 |
+
return
|
165 |
+
|
166 |
+
embed_image = Image.open(path)
|
167 |
+
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
168 |
+
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
169 |
+
name = data.get('name', name)
|
170 |
+
else:
|
171 |
+
data = extract_image_data_embed(embed_image)
|
172 |
+
if data:
|
173 |
+
name = data.get('name', name)
|
174 |
+
else:
|
175 |
+
# if data is None, means this is not an embeding, just a preview image
|
176 |
+
return
|
177 |
+
elif ext in ['.BIN', '.PT']:
|
178 |
+
data = torch.load(path, map_location="cpu")
|
179 |
+
elif ext in ['.SAFETENSORS']:
|
180 |
+
data = safetensors.torch.load_file(path, device="cpu")
|
181 |
+
else:
|
182 |
+
return
|
183 |
+
|
184 |
+
# textual inversion embeddings
|
185 |
+
if 'string_to_param' in data:
|
186 |
+
param_dict = data['string_to_param']
|
187 |
+
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
188 |
+
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
189 |
+
emb = next(iter(param_dict.items()))[1]
|
190 |
+
# diffuser concepts
|
191 |
+
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
192 |
+
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
193 |
+
|
194 |
+
emb = next(iter(data.values()))
|
195 |
+
if len(emb.shape) == 1:
|
196 |
+
emb = emb.unsqueeze(0)
|
197 |
+
else:
|
198 |
+
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
199 |
+
|
200 |
+
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
201 |
+
embedding = Embedding(vec, name)
|
202 |
+
embedding.step = data.get('step', None)
|
203 |
+
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
204 |
+
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
205 |
+
embedding.vectors = vec.shape[0]
|
206 |
+
embedding.shape = vec.shape[-1]
|
207 |
+
embedding.filename = path
|
208 |
+
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
209 |
+
|
210 |
+
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
211 |
+
self.register_embedding(embedding, shared.sd_model)
|
212 |
+
else:
|
213 |
+
self.skipped_embeddings[name] = embedding
|
214 |
+
|
215 |
+
def load_from_dir(self, embdir):
|
216 |
+
if not os.path.isdir(embdir.path):
|
217 |
+
return
|
218 |
+
|
219 |
+
for root, _, fns in os.walk(embdir.path, followlinks=True):
|
220 |
+
for fn in fns:
|
221 |
+
try:
|
222 |
+
fullfn = os.path.join(root, fn)
|
223 |
+
|
224 |
+
if os.stat(fullfn).st_size == 0:
|
225 |
+
continue
|
226 |
+
|
227 |
+
self.load_from_file(fullfn, fn)
|
228 |
+
except Exception:
|
229 |
+
errors.report(f"Error loading embedding {fn}", exc_info=True)
|
230 |
+
continue
|
231 |
+
|
232 |
+
def load_textual_inversion_embeddings(self, force_reload=False):
|
233 |
+
if not force_reload:
|
234 |
+
need_reload = False
|
235 |
+
for embdir in self.embedding_dirs.values():
|
236 |
+
if embdir.has_changed():
|
237 |
+
need_reload = True
|
238 |
+
break
|
239 |
+
|
240 |
+
if not need_reload:
|
241 |
+
return
|
242 |
+
|
243 |
+
self.ids_lookup.clear()
|
244 |
+
self.word_embeddings.clear()
|
245 |
+
self.skipped_embeddings.clear()
|
246 |
+
self.expected_shape = self.get_expected_shape()
|
247 |
+
|
248 |
+
for embdir in self.embedding_dirs.values():
|
249 |
+
self.load_from_dir(embdir)
|
250 |
+
embdir.update()
|
251 |
+
|
252 |
+
# re-sort word_embeddings because load_from_dir may not load in alphabetic order.
|
253 |
+
# using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
|
254 |
+
sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
|
255 |
+
self.word_embeddings.clear()
|
256 |
+
self.word_embeddings.update(sorted_word_embeddings)
|
257 |
+
|
258 |
+
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
259 |
+
if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
|
260 |
+
self.previously_displayed_embeddings = displayed_embeddings
|
261 |
+
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
262 |
+
if self.skipped_embeddings:
|
263 |
+
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
264 |
+
|
265 |
+
def find_embedding_at_position(self, tokens, offset):
|
266 |
+
token = tokens[offset]
|
267 |
+
possible_matches = self.ids_lookup.get(token, None)
|
268 |
+
|
269 |
+
if possible_matches is None:
|
270 |
+
return None, None
|
271 |
+
|
272 |
+
for ids, embedding in possible_matches:
|
273 |
+
if tokens[offset:offset + len(ids)] == ids:
|
274 |
+
return embedding, len(ids)
|
275 |
+
|
276 |
+
return None, None
|
277 |
+
|
278 |
+
|
279 |
+
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
280 |
+
cond_model = shared.sd_model.cond_stage_model
|
281 |
+
|
282 |
+
with devices.autocast():
|
283 |
+
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
284 |
+
|
285 |
+
#cond_model expects at least some text, so we provide '*' as backup.
|
286 |
+
embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
|
287 |
+
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
288 |
+
|
289 |
+
#Only copy if we provided an init_text, otherwise keep vectors as zeros
|
290 |
+
if init_text:
|
291 |
+
for i in range(num_vectors_per_token):
|
292 |
+
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
293 |
+
|
294 |
+
# Remove illegal characters from name.
|
295 |
+
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
296 |
+
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
297 |
+
if not overwrite_old:
|
298 |
+
assert not os.path.exists(fn), f"file {fn} already exists"
|
299 |
+
|
300 |
+
embedding = Embedding(vec, name)
|
301 |
+
embedding.step = 0
|
302 |
+
embedding.save(fn)
|
303 |
+
|
304 |
+
return fn
|
305 |
+
|
306 |
+
|
307 |
+
def write_loss(log_directory, filename, step, epoch_len, values):
|
308 |
+
if shared.opts.training_write_csv_every == 0:
|
309 |
+
return
|
310 |
+
|
311 |
+
if step % shared.opts.training_write_csv_every != 0:
|
312 |
+
return
|
313 |
+
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
314 |
+
|
315 |
+
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
|
316 |
+
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
|
317 |
+
|
318 |
+
if write_csv_header:
|
319 |
+
csv_writer.writeheader()
|
320 |
+
|
321 |
+
epoch = (step - 1) // epoch_len
|
322 |
+
epoch_step = (step - 1) % epoch_len
|
323 |
+
|
324 |
+
csv_writer.writerow({
|
325 |
+
"step": step,
|
326 |
+
"epoch": epoch,
|
327 |
+
"epoch_step": epoch_step,
|
328 |
+
**values,
|
329 |
+
})
|
330 |
+
|
331 |
+
def tensorboard_setup(log_directory):
|
332 |
+
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
|
333 |
+
return SummaryWriter(
|
334 |
+
log_dir=os.path.join(log_directory, "tensorboard"),
|
335 |
+
flush_secs=shared.opts.training_tensorboard_flush_every)
|
336 |
+
|
337 |
+
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
|
338 |
+
tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
|
339 |
+
tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
|
340 |
+
tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
|
341 |
+
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
|
342 |
+
|
343 |
+
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
|
344 |
+
tensorboard_writer.add_scalar(tag=tag,
|
345 |
+
scalar_value=value, global_step=step)
|
346 |
+
|
347 |
+
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
|
348 |
+
# Convert a pil image to a torch tensor
|
349 |
+
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
|
350 |
+
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
|
351 |
+
len(pil_image.getbands()))
|
352 |
+
img_tensor = img_tensor.permute((2, 0, 1))
|
353 |
+
|
354 |
+
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
|
355 |
+
|
356 |
+
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
357 |
+
assert model_name, f"{name} not selected"
|
358 |
+
assert learn_rate, "Learning rate is empty or 0"
|
359 |
+
assert isinstance(batch_size, int), "Batch size must be integer"
|
360 |
+
assert batch_size > 0, "Batch size must be positive"
|
361 |
+
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
|
362 |
+
assert gradient_step > 0, "Gradient accumulation step must be positive"
|
363 |
+
assert data_root, "Dataset directory is empty"
|
364 |
+
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
365 |
+
assert os.listdir(data_root), "Dataset directory is empty"
|
366 |
+
assert template_filename, "Prompt template file not selected"
|
367 |
+
assert template_file, f"Prompt template file {template_filename} not found"
|
368 |
+
assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
|
369 |
+
assert steps, "Max steps is empty or 0"
|
370 |
+
assert isinstance(steps, int), "Max steps must be integer"
|
371 |
+
assert steps > 0, "Max steps must be positive"
|
372 |
+
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
373 |
+
assert save_model_every >= 0, "Save {name} must be positive or 0"
|
374 |
+
assert isinstance(create_image_every, int), "Create image must be integer"
|
375 |
+
assert create_image_every >= 0, "Create image must be positive or 0"
|
376 |
+
if save_model_every or create_image_every:
|
377 |
+
assert log_directory, "Log directory is empty"
|
378 |
+
|
379 |
+
|
380 |
+
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
381 |
+
save_embedding_every = save_embedding_every or 0
|
382 |
+
create_image_every = create_image_every or 0
|
383 |
+
template_file = textual_inversion_templates.get(template_filename, None)
|
384 |
+
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
385 |
+
template_file = template_file.path
|
386 |
+
|
387 |
+
shared.state.job = "train-embedding"
|
388 |
+
shared.state.textinfo = "Initializing textual inversion training..."
|
389 |
+
shared.state.job_count = steps
|
390 |
+
|
391 |
+
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
392 |
+
|
393 |
+
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
394 |
+
unload = shared.opts.unload_models_when_training
|
395 |
+
|
396 |
+
if save_embedding_every > 0:
|
397 |
+
embedding_dir = os.path.join(log_directory, "embeddings")
|
398 |
+
os.makedirs(embedding_dir, exist_ok=True)
|
399 |
+
else:
|
400 |
+
embedding_dir = None
|
401 |
+
|
402 |
+
if create_image_every > 0:
|
403 |
+
images_dir = os.path.join(log_directory, "images")
|
404 |
+
os.makedirs(images_dir, exist_ok=True)
|
405 |
+
else:
|
406 |
+
images_dir = None
|
407 |
+
|
408 |
+
if create_image_every > 0 and save_image_with_stored_embedding:
|
409 |
+
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
|
410 |
+
os.makedirs(images_embeds_dir, exist_ok=True)
|
411 |
+
else:
|
412 |
+
images_embeds_dir = None
|
413 |
+
|
414 |
+
hijack = sd_hijack.model_hijack
|
415 |
+
|
416 |
+
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
417 |
+
checkpoint = sd_models.select_checkpoint()
|
418 |
+
|
419 |
+
initial_step = embedding.step or 0
|
420 |
+
if initial_step >= steps:
|
421 |
+
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
422 |
+
return embedding, filename
|
423 |
+
|
424 |
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
425 |
+
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
|
426 |
+
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
|
427 |
+
None
|
428 |
+
if clip_grad:
|
429 |
+
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
430 |
+
# dataset loading may take a while, so input validations and early returns should be done before this
|
431 |
+
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
432 |
+
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
433 |
+
|
434 |
+
if shared.opts.training_enable_tensorboard:
|
435 |
+
tensorboard_writer = tensorboard_setup(log_directory)
|
436 |
+
|
437 |
+
pin_memory = shared.opts.pin_memory
|
438 |
+
|
439 |
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
|
440 |
+
|
441 |
+
if shared.opts.save_training_settings_to_txt:
|
442 |
+
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
|
443 |
+
|
444 |
+
latent_sampling_method = ds.latent_sampling_method
|
445 |
+
|
446 |
+
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
447 |
+
|
448 |
+
if unload:
|
449 |
+
shared.parallel_processing_allowed = False
|
450 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
451 |
+
|
452 |
+
embedding.vec.requires_grad = True
|
453 |
+
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
454 |
+
if shared.opts.save_optimizer_state:
|
455 |
+
optimizer_state_dict = None
|
456 |
+
if os.path.exists(f"{filename}.optim"):
|
457 |
+
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
|
458 |
+
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
459 |
+
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
460 |
+
|
461 |
+
if optimizer_state_dict is not None:
|
462 |
+
optimizer.load_state_dict(optimizer_state_dict)
|
463 |
+
print("Loaded existing optimizer from checkpoint")
|
464 |
+
else:
|
465 |
+
print("No saved optimizer exists in checkpoint")
|
466 |
+
|
467 |
+
scaler = torch.cuda.amp.GradScaler()
|
468 |
+
|
469 |
+
batch_size = ds.batch_size
|
470 |
+
gradient_step = ds.gradient_step
|
471 |
+
# n steps = batch_size * gradient_step * n image processed
|
472 |
+
steps_per_epoch = len(ds) // batch_size // gradient_step
|
473 |
+
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
474 |
+
loss_step = 0
|
475 |
+
_loss_step = 0 #internal
|
476 |
+
|
477 |
+
last_saved_file = "<none>"
|
478 |
+
last_saved_image = "<none>"
|
479 |
+
forced_filename = "<none>"
|
480 |
+
embedding_yet_to_be_embedded = False
|
481 |
+
|
482 |
+
is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
483 |
+
img_c = None
|
484 |
+
|
485 |
+
pbar = tqdm.tqdm(total=steps - initial_step)
|
486 |
+
try:
|
487 |
+
sd_hijack_checkpoint.add()
|
488 |
+
|
489 |
+
for _ in range((steps-initial_step) * gradient_step):
|
490 |
+
if scheduler.finished:
|
491 |
+
break
|
492 |
+
if shared.state.interrupted:
|
493 |
+
break
|
494 |
+
for j, batch in enumerate(dl):
|
495 |
+
# works as a drop_last=True for gradient accumulation
|
496 |
+
if j == max_steps_per_epoch:
|
497 |
+
break
|
498 |
+
scheduler.apply(optimizer, embedding.step)
|
499 |
+
if scheduler.finished:
|
500 |
+
break
|
501 |
+
if shared.state.interrupted:
|
502 |
+
break
|
503 |
+
|
504 |
+
if clip_grad:
|
505 |
+
clip_grad_sched.step(embedding.step)
|
506 |
+
|
507 |
+
with devices.autocast():
|
508 |
+
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
509 |
+
if use_weight:
|
510 |
+
w = batch.weight.to(devices.device, non_blocking=pin_memory)
|
511 |
+
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
512 |
+
|
513 |
+
if is_training_inpainting_model:
|
514 |
+
if img_c is None:
|
515 |
+
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
516 |
+
|
517 |
+
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
518 |
+
else:
|
519 |
+
cond = c
|
520 |
+
|
521 |
+
if use_weight:
|
522 |
+
loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
|
523 |
+
del w
|
524 |
+
else:
|
525 |
+
loss = shared.sd_model.forward(x, cond)[0] / gradient_step
|
526 |
+
del x
|
527 |
+
|
528 |
+
_loss_step += loss.item()
|
529 |
+
scaler.scale(loss).backward()
|
530 |
+
|
531 |
+
# go back until we reach gradient accumulation steps
|
532 |
+
if (j + 1) % gradient_step != 0:
|
533 |
+
continue
|
534 |
+
|
535 |
+
if clip_grad:
|
536 |
+
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
|
537 |
+
|
538 |
+
scaler.step(optimizer)
|
539 |
+
scaler.update()
|
540 |
+
embedding.step += 1
|
541 |
+
pbar.update()
|
542 |
+
optimizer.zero_grad(set_to_none=True)
|
543 |
+
loss_step = _loss_step
|
544 |
+
_loss_step = 0
|
545 |
+
|
546 |
+
steps_done = embedding.step + 1
|
547 |
+
|
548 |
+
epoch_num = embedding.step // steps_per_epoch
|
549 |
+
epoch_step = embedding.step % steps_per_epoch
|
550 |
+
|
551 |
+
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
|
552 |
+
pbar.set_description(description)
|
553 |
+
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
554 |
+
# Before saving, change name to match current checkpoint.
|
555 |
+
embedding_name_every = f'{embedding_name}-{steps_done}'
|
556 |
+
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
557 |
+
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
558 |
+
embedding_yet_to_be_embedded = True
|
559 |
+
|
560 |
+
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
561 |
+
"loss": f"{loss_step:.7f}",
|
562 |
+
"learn_rate": scheduler.learn_rate
|
563 |
+
})
|
564 |
+
|
565 |
+
if images_dir is not None and steps_done % create_image_every == 0:
|
566 |
+
forced_filename = f'{embedding_name}-{steps_done}'
|
567 |
+
last_saved_image = os.path.join(images_dir, forced_filename)
|
568 |
+
|
569 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
570 |
+
|
571 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
572 |
+
sd_model=shared.sd_model,
|
573 |
+
do_not_save_grid=True,
|
574 |
+
do_not_save_samples=True,
|
575 |
+
do_not_reload_embeddings=True,
|
576 |
+
)
|
577 |
+
|
578 |
+
if preview_from_txt2img:
|
579 |
+
p.prompt = preview_prompt
|
580 |
+
p.negative_prompt = preview_negative_prompt
|
581 |
+
p.steps = preview_steps
|
582 |
+
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
583 |
+
p.cfg_scale = preview_cfg_scale
|
584 |
+
p.seed = preview_seed
|
585 |
+
p.width = preview_width
|
586 |
+
p.height = preview_height
|
587 |
+
else:
|
588 |
+
p.prompt = batch.cond_text[0]
|
589 |
+
p.steps = 20
|
590 |
+
p.width = training_width
|
591 |
+
p.height = training_height
|
592 |
+
|
593 |
+
preview_text = p.prompt
|
594 |
+
|
595 |
+
with closing(p):
|
596 |
+
processed = processing.process_images(p)
|
597 |
+
image = processed.images[0] if len(processed.images) > 0 else None
|
598 |
+
|
599 |
+
if unload:
|
600 |
+
shared.sd_model.first_stage_model.to(devices.cpu)
|
601 |
+
|
602 |
+
if image is not None:
|
603 |
+
shared.state.assign_current_image(image)
|
604 |
+
|
605 |
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
606 |
+
last_saved_image += f", prompt: {preview_text}"
|
607 |
+
|
608 |
+
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
609 |
+
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
|
610 |
+
|
611 |
+
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
612 |
+
|
613 |
+
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
614 |
+
|
615 |
+
info = PngImagePlugin.PngInfo()
|
616 |
+
data = torch.load(last_saved_file)
|
617 |
+
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
618 |
+
|
619 |
+
title = f"<{data.get('name', '???')}>"
|
620 |
+
|
621 |
+
try:
|
622 |
+
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
623 |
+
except Exception:
|
624 |
+
vectorSize = '?'
|
625 |
+
|
626 |
+
checkpoint = sd_models.select_checkpoint()
|
627 |
+
footer_left = checkpoint.model_name
|
628 |
+
footer_mid = f'[{checkpoint.shorthash}]'
|
629 |
+
footer_right = f'{vectorSize}v {steps_done}s'
|
630 |
+
|
631 |
+
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
632 |
+
captioned_image = insert_image_data_embed(captioned_image, data)
|
633 |
+
|
634 |
+
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
635 |
+
embedding_yet_to_be_embedded = False
|
636 |
+
|
637 |
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
638 |
+
last_saved_image += f", prompt: {preview_text}"
|
639 |
+
|
640 |
+
shared.state.job_no = embedding.step
|
641 |
+
|
642 |
+
shared.state.textinfo = f"""
|
643 |
+
<p>
|
644 |
+
Loss: {loss_step:.7f}<br/>
|
645 |
+
Step: {steps_done}<br/>
|
646 |
+
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
647 |
+
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
648 |
+
Last saved image: {html.escape(last_saved_image)}<br/>
|
649 |
+
</p>
|
650 |
+
"""
|
651 |
+
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
652 |
+
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
653 |
+
except Exception:
|
654 |
+
errors.report("Error training embedding", exc_info=True)
|
655 |
+
finally:
|
656 |
+
pbar.leave = False
|
657 |
+
pbar.close()
|
658 |
+
shared.sd_model.first_stage_model.to(devices.device)
|
659 |
+
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
660 |
+
sd_hijack_checkpoint.remove()
|
661 |
+
|
662 |
+
return embedding, filename
|
663 |
+
|
664 |
+
|
665 |
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
666 |
+
old_embedding_name = embedding.name
|
667 |
+
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
668 |
+
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
669 |
+
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
|
670 |
+
try:
|
671 |
+
embedding.sd_checkpoint = checkpoint.shorthash
|
672 |
+
embedding.sd_checkpoint_name = checkpoint.model_name
|
673 |
+
if remove_cached_checksum:
|
674 |
+
embedding.cached_checksum = None
|
675 |
+
embedding.name = embedding_name
|
676 |
+
embedding.optimizer_state_dict = optimizer.state_dict()
|
677 |
+
embedding.save(filename)
|
678 |
+
except:
|
679 |
+
embedding.sd_checkpoint = old_sd_checkpoint
|
680 |
+
embedding.sd_checkpoint_name = old_sd_checkpoint_name
|
681 |
+
embedding.name = old_embedding_name
|
682 |
+
embedding.cached_checksum = old_cached_checksum
|
683 |
+
raise
|
modules/textual_inversion/ui.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
import modules.textual_inversion.textual_inversion
|
6 |
+
import modules.textual_inversion.preprocess
|
7 |
+
from modules import sd_hijack, shared
|
8 |
+
|
9 |
+
|
10 |
+
def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
11 |
+
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
|
12 |
+
|
13 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
14 |
+
|
15 |
+
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
16 |
+
|
17 |
+
|
18 |
+
def preprocess(*args):
|
19 |
+
modules.textual_inversion.preprocess.preprocess(*args)
|
20 |
+
|
21 |
+
return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
|
22 |
+
|
23 |
+
|
24 |
+
def train_embedding(*args):
|
25 |
+
|
26 |
+
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
27 |
+
|
28 |
+
apply_optimizations = shared.opts.training_xattention_optimizations
|
29 |
+
try:
|
30 |
+
if not apply_optimizations:
|
31 |
+
sd_hijack.undo_optimizations()
|
32 |
+
|
33 |
+
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
34 |
+
|
35 |
+
res = f"""
|
36 |
+
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
|
37 |
+
Embedding saved to {html.escape(filename)}
|
38 |
+
"""
|
39 |
+
return res, ""
|
40 |
+
except Exception:
|
41 |
+
raise
|
42 |
+
finally:
|
43 |
+
if not apply_optimizations:
|
44 |
+
sd_hijack.apply_optimizations()
|
45 |
+
|
modules/timer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
|
5 |
+
class TimerSubcategory:
|
6 |
+
def __init__(self, timer, category):
|
7 |
+
self.timer = timer
|
8 |
+
self.category = category
|
9 |
+
self.start = None
|
10 |
+
self.original_base_category = timer.base_category
|
11 |
+
|
12 |
+
def __enter__(self):
|
13 |
+
self.start = time.time()
|
14 |
+
self.timer.base_category = self.original_base_category + self.category + "/"
|
15 |
+
self.timer.subcategory_level += 1
|
16 |
+
|
17 |
+
if self.timer.print_log:
|
18 |
+
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
|
19 |
+
|
20 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
21 |
+
elapsed_for_subcategroy = time.time() - self.start
|
22 |
+
self.timer.base_category = self.original_base_category
|
23 |
+
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
|
24 |
+
self.timer.subcategory_level -= 1
|
25 |
+
self.timer.record(self.category, disable_log=True)
|
26 |
+
|
27 |
+
|
28 |
+
class Timer:
|
29 |
+
def __init__(self, print_log=False):
|
30 |
+
self.start = time.time()
|
31 |
+
self.records = {}
|
32 |
+
self.total = 0
|
33 |
+
self.base_category = ''
|
34 |
+
self.print_log = print_log
|
35 |
+
self.subcategory_level = 0
|
36 |
+
|
37 |
+
def elapsed(self):
|
38 |
+
end = time.time()
|
39 |
+
res = end - self.start
|
40 |
+
self.start = end
|
41 |
+
return res
|
42 |
+
|
43 |
+
def add_time_to_record(self, category, amount):
|
44 |
+
if category not in self.records:
|
45 |
+
self.records[category] = 0
|
46 |
+
|
47 |
+
self.records[category] += amount
|
48 |
+
|
49 |
+
def record(self, category, extra_time=0, disable_log=False):
|
50 |
+
e = self.elapsed()
|
51 |
+
|
52 |
+
self.add_time_to_record(self.base_category + category, e + extra_time)
|
53 |
+
|
54 |
+
self.total += e + extra_time
|
55 |
+
|
56 |
+
if self.print_log and not disable_log:
|
57 |
+
print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
|
58 |
+
|
59 |
+
def subcategory(self, name):
|
60 |
+
self.elapsed()
|
61 |
+
|
62 |
+
subcat = TimerSubcategory(self, name)
|
63 |
+
return subcat
|
64 |
+
|
65 |
+
def summary(self):
|
66 |
+
res = f"{self.total:.1f}s"
|
67 |
+
|
68 |
+
additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
|
69 |
+
if not additions:
|
70 |
+
return res
|
71 |
+
|
72 |
+
res += " ("
|
73 |
+
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
|
74 |
+
res += ")"
|
75 |
+
|
76 |
+
return res
|
77 |
+
|
78 |
+
def dump(self):
|
79 |
+
return {'total': self.total, 'records': self.records}
|
80 |
+
|
81 |
+
def reset(self):
|
82 |
+
self.__init__()
|
83 |
+
|
84 |
+
|
85 |
+
parser = argparse.ArgumentParser(add_help=False)
|
86 |
+
parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
|
87 |
+
args = parser.parse_known_args()[0]
|
88 |
+
|
89 |
+
startup_timer = Timer(print_log=args.log_startup)
|
90 |
+
|
91 |
+
startup_record = None
|
modules/txt2img.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import closing
|
2 |
+
|
3 |
+
import modules.scripts
|
4 |
+
from modules import sd_samplers, processing
|
5 |
+
from modules.generation_parameters_copypaste import create_override_settings_dict
|
6 |
+
from modules.shared import opts, cmd_opts
|
7 |
+
import modules.shared as shared
|
8 |
+
from modules.ui import plaintext_to_html
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
|
12 |
+
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
13 |
+
override_settings = create_override_settings_dict(override_settings_texts)
|
14 |
+
|
15 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
16 |
+
sd_model=shared.sd_model,
|
17 |
+
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
18 |
+
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
|
19 |
+
prompt=prompt,
|
20 |
+
styles=prompt_styles,
|
21 |
+
negative_prompt=negative_prompt,
|
22 |
+
seed=seed,
|
23 |
+
subseed=subseed,
|
24 |
+
subseed_strength=subseed_strength,
|
25 |
+
seed_resize_from_h=seed_resize_from_h,
|
26 |
+
seed_resize_from_w=seed_resize_from_w,
|
27 |
+
seed_enable_extras=seed_enable_extras,
|
28 |
+
sampler_name=sd_samplers.samplers[sampler_index].name,
|
29 |
+
batch_size=batch_size,
|
30 |
+
n_iter=n_iter,
|
31 |
+
steps=steps,
|
32 |
+
cfg_scale=cfg_scale,
|
33 |
+
width=width,
|
34 |
+
height=height,
|
35 |
+
restore_faces=restore_faces,
|
36 |
+
tiling=tiling,
|
37 |
+
enable_hr=enable_hr,
|
38 |
+
denoising_strength=denoising_strength if enable_hr else None,
|
39 |
+
hr_scale=hr_scale,
|
40 |
+
hr_upscaler=hr_upscaler,
|
41 |
+
hr_second_pass_steps=hr_second_pass_steps,
|
42 |
+
hr_resize_x=hr_resize_x,
|
43 |
+
hr_resize_y=hr_resize_y,
|
44 |
+
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
|
45 |
+
hr_prompt=hr_prompt,
|
46 |
+
hr_negative_prompt=hr_negative_prompt,
|
47 |
+
override_settings=override_settings,
|
48 |
+
)
|
49 |
+
|
50 |
+
p.scripts = modules.scripts.scripts_txt2img
|
51 |
+
p.script_args = args
|
52 |
+
|
53 |
+
p.user = request.username
|
54 |
+
|
55 |
+
if cmd_opts.enable_console_prompts:
|
56 |
+
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
57 |
+
|
58 |
+
with closing(p):
|
59 |
+
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
60 |
+
|
61 |
+
if processed is None:
|
62 |
+
processed = processing.process_images(p)
|
63 |
+
|
64 |
+
shared.total_tqdm.clear()
|
65 |
+
|
66 |
+
generation_info_js = processed.js()
|
67 |
+
if opts.samples_log_stdout:
|
68 |
+
print(generation_info_js)
|
69 |
+
|
70 |
+
if opts.do_not_show_images:
|
71 |
+
processed.images = []
|
72 |
+
|
73 |
+
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
modules/ui.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modules/ui_common.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
import platform
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import subprocess as sp
|
9 |
+
|
10 |
+
from modules import call_queue, shared
|
11 |
+
from modules.generation_parameters_copypaste import image_from_url_text
|
12 |
+
import modules.images
|
13 |
+
from modules.ui_components import ToolButton
|
14 |
+
|
15 |
+
|
16 |
+
folder_symbol = '\U0001f4c2' # 📂
|
17 |
+
refresh_symbol = '\U0001f504' # 🔄
|
18 |
+
|
19 |
+
|
20 |
+
def update_generation_info(generation_info, html_info, img_index):
|
21 |
+
try:
|
22 |
+
generation_info = json.loads(generation_info)
|
23 |
+
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
|
24 |
+
return html_info, gr.update()
|
25 |
+
return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
|
26 |
+
except Exception:
|
27 |
+
pass
|
28 |
+
# if the json parse or anything else fails, just return the old html_info
|
29 |
+
return html_info, gr.update()
|
30 |
+
|
31 |
+
|
32 |
+
def plaintext_to_html(text, classname=None):
|
33 |
+
content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
|
34 |
+
|
35 |
+
return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
|
36 |
+
|
37 |
+
|
38 |
+
def save_files(js_data, images, do_make_zip, index):
|
39 |
+
import csv
|
40 |
+
filenames = []
|
41 |
+
fullfns = []
|
42 |
+
|
43 |
+
#quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
44 |
+
class MyObject:
|
45 |
+
def __init__(self, d=None):
|
46 |
+
if d is not None:
|
47 |
+
for key, value in d.items():
|
48 |
+
setattr(self, key, value)
|
49 |
+
|
50 |
+
data = json.loads(js_data)
|
51 |
+
|
52 |
+
p = MyObject(data)
|
53 |
+
path = shared.opts.outdir_save
|
54 |
+
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
|
55 |
+
extension: str = shared.opts.samples_format
|
56 |
+
start_index = 0
|
57 |
+
only_one = False
|
58 |
+
|
59 |
+
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
60 |
+
only_one = True
|
61 |
+
images = [images[index]]
|
62 |
+
start_index = index
|
63 |
+
|
64 |
+
os.makedirs(shared.opts.outdir_save, exist_ok=True)
|
65 |
+
|
66 |
+
with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
67 |
+
at_start = file.tell() == 0
|
68 |
+
writer = csv.writer(file)
|
69 |
+
if at_start:
|
70 |
+
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
71 |
+
|
72 |
+
for image_index, filedata in enumerate(images, start_index):
|
73 |
+
image = image_from_url_text(filedata)
|
74 |
+
|
75 |
+
is_grid = image_index < p.index_of_first_image
|
76 |
+
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
77 |
+
|
78 |
+
p.batch_index = image_index-1
|
79 |
+
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
80 |
+
|
81 |
+
filename = os.path.relpath(fullfn, path)
|
82 |
+
filenames.append(filename)
|
83 |
+
fullfns.append(fullfn)
|
84 |
+
if txt_fullfn:
|
85 |
+
filenames.append(os.path.basename(txt_fullfn))
|
86 |
+
fullfns.append(txt_fullfn)
|
87 |
+
|
88 |
+
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
89 |
+
|
90 |
+
# Make Zip
|
91 |
+
if do_make_zip:
|
92 |
+
zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
|
93 |
+
namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
|
94 |
+
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
|
95 |
+
zip_filepath = os.path.join(path, f"{zip_filename}.zip")
|
96 |
+
|
97 |
+
from zipfile import ZipFile
|
98 |
+
with ZipFile(zip_filepath, "w") as zip_file:
|
99 |
+
for i in range(len(fullfns)):
|
100 |
+
with open(fullfns[i], mode="rb") as f:
|
101 |
+
zip_file.writestr(filenames[i], f.read())
|
102 |
+
fullfns.insert(0, zip_filepath)
|
103 |
+
|
104 |
+
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
105 |
+
|
106 |
+
|
107 |
+
def create_output_panel(tabname, outdir):
|
108 |
+
from modules import shared
|
109 |
+
import modules.generation_parameters_copypaste as parameters_copypaste
|
110 |
+
|
111 |
+
def open_folder(f):
|
112 |
+
if not os.path.exists(f):
|
113 |
+
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
114 |
+
return
|
115 |
+
elif not os.path.isdir(f):
|
116 |
+
print(f"""
|
117 |
+
WARNING
|
118 |
+
An open_folder request was made with an argument that is not a folder.
|
119 |
+
This could be an error or a malicious attempt to run code on your computer.
|
120 |
+
Requested path was: {f}
|
121 |
+
""", file=sys.stderr)
|
122 |
+
return
|
123 |
+
|
124 |
+
if not shared.cmd_opts.hide_ui_dir_config:
|
125 |
+
path = os.path.normpath(f)
|
126 |
+
if platform.system() == "Windows":
|
127 |
+
os.startfile(path)
|
128 |
+
elif platform.system() == "Darwin":
|
129 |
+
sp.Popen(["open", path])
|
130 |
+
elif "microsoft-standard-WSL2" in platform.uname().release:
|
131 |
+
sp.Popen(["wsl-open", path])
|
132 |
+
else:
|
133 |
+
sp.Popen(["xdg-open", path])
|
134 |
+
|
135 |
+
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
136 |
+
|
137 |
+
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
138 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
|
139 |
+
|
140 |
+
generation_info = None
|
141 |
+
with gr.Column():
|
142 |
+
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
143 |
+
open_folder_button = gr.Button(folder_symbol, elem_id=f'open_folder_{tabname}', visible=not shared.cmd_opts.hide_ui_dir_config)
|
144 |
+
|
145 |
+
if tabname != "extras":
|
146 |
+
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
147 |
+
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
|
148 |
+
|
149 |
+
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
150 |
+
|
151 |
+
open_folder_button.click(
|
152 |
+
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
153 |
+
inputs=[],
|
154 |
+
outputs=[],
|
155 |
+
)
|
156 |
+
|
157 |
+
if tabname != "extras":
|
158 |
+
with gr.Row():
|
159 |
+
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
160 |
+
|
161 |
+
with gr.Accordion("Generation Info", open=False):
|
162 |
+
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
163 |
+
html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
|
164 |
+
|
165 |
+
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
166 |
+
if tabname == 'txt2img' or tabname == 'img2img':
|
167 |
+
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
|
168 |
+
generation_info_button.click(
|
169 |
+
fn=update_generation_info,
|
170 |
+
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
171 |
+
inputs=[generation_info, html_info, html_info],
|
172 |
+
outputs=[html_info, html_info],
|
173 |
+
show_progress=False,
|
174 |
+
)
|
175 |
+
|
176 |
+
save.click(
|
177 |
+
fn=call_queue.wrap_gradio_call(save_files),
|
178 |
+
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
|
179 |
+
inputs=[
|
180 |
+
generation_info,
|
181 |
+
result_gallery,
|
182 |
+
html_info,
|
183 |
+
html_info,
|
184 |
+
],
|
185 |
+
outputs=[
|
186 |
+
download_files,
|
187 |
+
html_log,
|
188 |
+
],
|
189 |
+
show_progress=False,
|
190 |
+
)
|
191 |
+
|
192 |
+
save_zip.click(
|
193 |
+
fn=call_queue.wrap_gradio_call(save_files),
|
194 |
+
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
|
195 |
+
inputs=[
|
196 |
+
generation_info,
|
197 |
+
result_gallery,
|
198 |
+
html_info,
|
199 |
+
html_info,
|
200 |
+
],
|
201 |
+
outputs=[
|
202 |
+
download_files,
|
203 |
+
html_log,
|
204 |
+
]
|
205 |
+
)
|
206 |
+
|
207 |
+
else:
|
208 |
+
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
209 |
+
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
210 |
+
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
211 |
+
|
212 |
+
paste_field_names = []
|
213 |
+
if tabname == "txt2img":
|
214 |
+
paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
|
215 |
+
elif tabname == "img2img":
|
216 |
+
paste_field_names = modules.scripts.scripts_img2img.paste_field_names
|
217 |
+
|
218 |
+
for paste_tabname, paste_button in buttons.items():
|
219 |
+
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
220 |
+
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
|
221 |
+
paste_field_names=paste_field_names
|
222 |
+
))
|
223 |
+
|
224 |
+
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
225 |
+
|
226 |
+
|
227 |
+
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
228 |
+
def refresh():
|
229 |
+
refresh_method()
|
230 |
+
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
231 |
+
|
232 |
+
for k, v in args.items():
|
233 |
+
setattr(refresh_component, k, v)
|
234 |
+
|
235 |
+
return gr.update(**(args or {}))
|
236 |
+
|
237 |
+
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
|
238 |
+
refresh_button.click(
|
239 |
+
fn=refresh,
|
240 |
+
inputs=[],
|
241 |
+
outputs=[refresh_component]
|
242 |
+
)
|
243 |
+
return refresh_button
|
244 |
+
|
modules/ui_components.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
|
4 |
+
class FormComponent:
|
5 |
+
def get_expected_parent(self):
|
6 |
+
return gr.components.Form
|
7 |
+
|
8 |
+
|
9 |
+
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
|
10 |
+
|
11 |
+
|
12 |
+
class ToolButton(FormComponent, gr.Button):
|
13 |
+
"""Small button with single emoji as text, fits inside gradio forms"""
|
14 |
+
|
15 |
+
def __init__(self, *args, **kwargs):
|
16 |
+
classes = kwargs.pop("elem_classes", [])
|
17 |
+
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
|
18 |
+
|
19 |
+
def get_block_name(self):
|
20 |
+
return "button"
|
21 |
+
|
22 |
+
|
23 |
+
class FormRow(FormComponent, gr.Row):
|
24 |
+
"""Same as gr.Row but fits inside gradio forms"""
|
25 |
+
|
26 |
+
def get_block_name(self):
|
27 |
+
return "row"
|
28 |
+
|
29 |
+
|
30 |
+
class FormColumn(FormComponent, gr.Column):
|
31 |
+
"""Same as gr.Column but fits inside gradio forms"""
|
32 |
+
|
33 |
+
def get_block_name(self):
|
34 |
+
return "column"
|
35 |
+
|
36 |
+
|
37 |
+
class FormGroup(FormComponent, gr.Group):
|
38 |
+
"""Same as gr.Row but fits inside gradio forms"""
|
39 |
+
|
40 |
+
def get_block_name(self):
|
41 |
+
return "group"
|
42 |
+
|
43 |
+
|
44 |
+
class FormHTML(FormComponent, gr.HTML):
|
45 |
+
"""Same as gr.HTML but fits inside gradio forms"""
|
46 |
+
|
47 |
+
def get_block_name(self):
|
48 |
+
return "html"
|
49 |
+
|
50 |
+
|
51 |
+
class FormColorPicker(FormComponent, gr.ColorPicker):
|
52 |
+
"""Same as gr.ColorPicker but fits inside gradio forms"""
|
53 |
+
|
54 |
+
def get_block_name(self):
|
55 |
+
return "colorpicker"
|
56 |
+
|
57 |
+
|
58 |
+
class DropdownMulti(FormComponent, gr.Dropdown):
|
59 |
+
"""Same as gr.Dropdown but always multiselect"""
|
60 |
+
def __init__(self, **kwargs):
|
61 |
+
super().__init__(multiselect=True, **kwargs)
|
62 |
+
|
63 |
+
def get_block_name(self):
|
64 |
+
return "dropdown"
|
65 |
+
|
66 |
+
|
67 |
+
class DropdownEditable(FormComponent, gr.Dropdown):
|
68 |
+
"""Same as gr.Dropdown but allows editing value"""
|
69 |
+
def __init__(self, **kwargs):
|
70 |
+
super().__init__(allow_custom_value=True, **kwargs)
|
71 |
+
|
72 |
+
def get_block_name(self):
|
73 |
+
return "dropdown"
|
74 |
+
|
modules/ui_extensions.py
ADDED
@@ -0,0 +1,651 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
import git
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import html
|
11 |
+
import shutil
|
12 |
+
import errno
|
13 |
+
|
14 |
+
from modules import extensions, shared, paths, config_states, errors, restart
|
15 |
+
from modules.paths_internal import config_states_dir
|
16 |
+
from modules.call_queue import wrap_gradio_gpu_call
|
17 |
+
|
18 |
+
available_extensions = {"extensions": []}
|
19 |
+
STYLE_PRIMARY = ' style="color: var(--primary-400)"'
|
20 |
+
|
21 |
+
|
22 |
+
def check_access():
|
23 |
+
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
|
24 |
+
|
25 |
+
|
26 |
+
def apply_and_restart(disable_list, update_list, disable_all):
|
27 |
+
check_access()
|
28 |
+
|
29 |
+
disabled = json.loads(disable_list)
|
30 |
+
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
31 |
+
|
32 |
+
update = json.loads(update_list)
|
33 |
+
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
34 |
+
|
35 |
+
if update:
|
36 |
+
save_config_state("Backup (pre-update)")
|
37 |
+
|
38 |
+
update = set(update)
|
39 |
+
|
40 |
+
for ext in extensions.extensions:
|
41 |
+
if ext.name not in update:
|
42 |
+
continue
|
43 |
+
|
44 |
+
try:
|
45 |
+
ext.fetch_and_reset_hard()
|
46 |
+
except Exception:
|
47 |
+
errors.report(f"Error getting updates for {ext.name}", exc_info=True)
|
48 |
+
|
49 |
+
shared.opts.disabled_extensions = disabled
|
50 |
+
shared.opts.disable_all_extensions = disable_all
|
51 |
+
shared.opts.save(shared.config_filename)
|
52 |
+
|
53 |
+
if restart.is_restartable():
|
54 |
+
restart.restart_program()
|
55 |
+
else:
|
56 |
+
restart.stop_program()
|
57 |
+
|
58 |
+
|
59 |
+
def save_config_state(name):
|
60 |
+
current_config_state = config_states.get_config()
|
61 |
+
if not name:
|
62 |
+
name = "Config"
|
63 |
+
current_config_state["name"] = name
|
64 |
+
timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
|
65 |
+
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
66 |
+
print(f"Saving backup of webui/extension state to {filename}.")
|
67 |
+
with open(filename, "w", encoding="utf-8") as f:
|
68 |
+
json.dump(current_config_state, f)
|
69 |
+
config_states.list_config_states()
|
70 |
+
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
71 |
+
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
72 |
+
return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
|
73 |
+
|
74 |
+
|
75 |
+
def restore_config_state(confirmed, config_state_name, restore_type):
|
76 |
+
if config_state_name == "Current":
|
77 |
+
return "<span>Select a config to restore from.</span>"
|
78 |
+
if not confirmed:
|
79 |
+
return "<span>Cancelled.</span>"
|
80 |
+
|
81 |
+
check_access()
|
82 |
+
|
83 |
+
config_state = config_states.all_config_states[config_state_name]
|
84 |
+
|
85 |
+
print(f"*** Restoring webui state from backup: {restore_type} ***")
|
86 |
+
|
87 |
+
if restore_type == "extensions" or restore_type == "both":
|
88 |
+
shared.opts.restore_config_state_file = config_state["filepath"]
|
89 |
+
shared.opts.save(shared.config_filename)
|
90 |
+
|
91 |
+
if restore_type == "webui" or restore_type == "both":
|
92 |
+
config_states.restore_webui_config(config_state)
|
93 |
+
|
94 |
+
shared.state.request_restart()
|
95 |
+
|
96 |
+
return ""
|
97 |
+
|
98 |
+
|
99 |
+
def check_updates(id_task, disable_list):
|
100 |
+
check_access()
|
101 |
+
|
102 |
+
disabled = json.loads(disable_list)
|
103 |
+
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
104 |
+
|
105 |
+
exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
|
106 |
+
shared.state.job_count = len(exts)
|
107 |
+
|
108 |
+
for ext in exts:
|
109 |
+
shared.state.textinfo = ext.name
|
110 |
+
|
111 |
+
try:
|
112 |
+
ext.check_updates()
|
113 |
+
except FileNotFoundError as e:
|
114 |
+
if 'FETCH_HEAD' not in str(e):
|
115 |
+
raise
|
116 |
+
except Exception:
|
117 |
+
errors.report(f"Error checking updates for {ext.name}", exc_info=True)
|
118 |
+
|
119 |
+
shared.state.nextjob()
|
120 |
+
|
121 |
+
return extension_table(), ""
|
122 |
+
|
123 |
+
|
124 |
+
def make_commit_link(commit_hash, remote, text=None):
|
125 |
+
if text is None:
|
126 |
+
text = commit_hash[:8]
|
127 |
+
if remote.startswith("https://github.com/"):
|
128 |
+
if remote.endswith(".git"):
|
129 |
+
remote = remote[:-4]
|
130 |
+
href = remote + "/commit/" + commit_hash
|
131 |
+
return f'<a href="{href}" target="_blank">{text}</a>'
|
132 |
+
else:
|
133 |
+
return text
|
134 |
+
|
135 |
+
|
136 |
+
def extension_table():
|
137 |
+
code = f"""<!-- {time.time()} -->
|
138 |
+
<table id="extensions">
|
139 |
+
<thead>
|
140 |
+
<tr>
|
141 |
+
<th>
|
142 |
+
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
|
143 |
+
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
|
144 |
+
</th>
|
145 |
+
<th>URL</th>
|
146 |
+
<th>Branch</th>
|
147 |
+
<th>Version</th>
|
148 |
+
<th>Date</th>
|
149 |
+
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
|
150 |
+
</tr>
|
151 |
+
</thead>
|
152 |
+
<tbody>
|
153 |
+
"""
|
154 |
+
|
155 |
+
for ext in extensions.extensions:
|
156 |
+
ext: extensions.Extension
|
157 |
+
ext.read_info_from_repo()
|
158 |
+
|
159 |
+
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
|
160 |
+
|
161 |
+
if ext.can_update:
|
162 |
+
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
|
163 |
+
else:
|
164 |
+
ext_status = ext.status
|
165 |
+
|
166 |
+
style = ""
|
167 |
+
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
|
168 |
+
style = STYLE_PRIMARY
|
169 |
+
|
170 |
+
version_link = ext.version
|
171 |
+
if ext.commit_hash and ext.remote:
|
172 |
+
version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
|
173 |
+
|
174 |
+
code += f"""
|
175 |
+
<tr>
|
176 |
+
<td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
|
177 |
+
<td>{remote}</td>
|
178 |
+
<td>{ext.branch}</td>
|
179 |
+
<td>{version_link}</td>
|
180 |
+
<td>{time.asctime(time.gmtime(ext.commit_date))}</td>
|
181 |
+
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
182 |
+
</tr>
|
183 |
+
"""
|
184 |
+
|
185 |
+
code += """
|
186 |
+
</tbody>
|
187 |
+
</table>
|
188 |
+
"""
|
189 |
+
|
190 |
+
return code
|
191 |
+
|
192 |
+
|
193 |
+
def update_config_states_table(state_name):
|
194 |
+
if state_name == "Current":
|
195 |
+
config_state = config_states.get_config()
|
196 |
+
else:
|
197 |
+
config_state = config_states.all_config_states[state_name]
|
198 |
+
|
199 |
+
config_name = config_state.get("name", "Config")
|
200 |
+
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
201 |
+
filepath = config_state.get("filepath", "<unknown>")
|
202 |
+
|
203 |
+
code = f"""<!-- {time.time()} -->"""
|
204 |
+
|
205 |
+
webui_remote = config_state["webui"]["remote"] or ""
|
206 |
+
webui_branch = config_state["webui"]["branch"]
|
207 |
+
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
|
208 |
+
webui_commit_date = config_state["webui"]["commit_date"]
|
209 |
+
if webui_commit_date:
|
210 |
+
webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
|
211 |
+
else:
|
212 |
+
webui_commit_date = "<unknown>"
|
213 |
+
|
214 |
+
remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
|
215 |
+
commit_link = make_commit_link(webui_commit_hash, webui_remote)
|
216 |
+
date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
|
217 |
+
|
218 |
+
current_webui = config_states.get_webui_config()
|
219 |
+
|
220 |
+
style_remote = ""
|
221 |
+
style_branch = ""
|
222 |
+
style_commit = ""
|
223 |
+
if current_webui["remote"] != webui_remote:
|
224 |
+
style_remote = STYLE_PRIMARY
|
225 |
+
if current_webui["branch"] != webui_branch:
|
226 |
+
style_branch = STYLE_PRIMARY
|
227 |
+
if current_webui["commit_hash"] != webui_commit_hash:
|
228 |
+
style_commit = STYLE_PRIMARY
|
229 |
+
|
230 |
+
code += f"""<h2>Config Backup: {config_name}</h2>
|
231 |
+
<div><b>Filepath:</b> {filepath}</div>
|
232 |
+
<div><b>Created at:</b> {created_date}</div>"""
|
233 |
+
|
234 |
+
code += f"""<h2>WebUI State</h2>
|
235 |
+
<table id="config_state_webui">
|
236 |
+
<thead>
|
237 |
+
<tr>
|
238 |
+
<th>URL</th>
|
239 |
+
<th>Branch</th>
|
240 |
+
<th>Commit</th>
|
241 |
+
<th>Date</th>
|
242 |
+
</tr>
|
243 |
+
</thead>
|
244 |
+
<tbody>
|
245 |
+
<tr>
|
246 |
+
<td><label{style_remote}>{remote}</label></td>
|
247 |
+
<td><label{style_branch}>{webui_branch}</label></td>
|
248 |
+
<td><label{style_commit}>{commit_link}</label></td>
|
249 |
+
<td><label{style_commit}>{date_link}</label></td>
|
250 |
+
</tr>
|
251 |
+
</tbody>
|
252 |
+
</table>
|
253 |
+
"""
|
254 |
+
|
255 |
+
code += """<h2>Extension State</h2>
|
256 |
+
<table id="config_state_extensions">
|
257 |
+
<thead>
|
258 |
+
<tr>
|
259 |
+
<th>Extension</th>
|
260 |
+
<th>URL</th>
|
261 |
+
<th>Branch</th>
|
262 |
+
<th>Commit</th>
|
263 |
+
<th>Date</th>
|
264 |
+
</tr>
|
265 |
+
</thead>
|
266 |
+
<tbody>
|
267 |
+
"""
|
268 |
+
|
269 |
+
ext_map = {ext.name: ext for ext in extensions.extensions}
|
270 |
+
|
271 |
+
for ext_name, ext_conf in config_state["extensions"].items():
|
272 |
+
ext_remote = ext_conf["remote"] or ""
|
273 |
+
ext_branch = ext_conf["branch"] or "<unknown>"
|
274 |
+
ext_enabled = ext_conf["enabled"]
|
275 |
+
ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
|
276 |
+
ext_commit_date = ext_conf["commit_date"]
|
277 |
+
if ext_commit_date:
|
278 |
+
ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
|
279 |
+
else:
|
280 |
+
ext_commit_date = "<unknown>"
|
281 |
+
|
282 |
+
remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
|
283 |
+
commit_link = make_commit_link(ext_commit_hash, ext_remote)
|
284 |
+
date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
|
285 |
+
|
286 |
+
style_enabled = ""
|
287 |
+
style_remote = ""
|
288 |
+
style_branch = ""
|
289 |
+
style_commit = ""
|
290 |
+
if ext_name in ext_map:
|
291 |
+
current_ext = ext_map[ext_name]
|
292 |
+
current_ext.read_info_from_repo()
|
293 |
+
if current_ext.enabled != ext_enabled:
|
294 |
+
style_enabled = STYLE_PRIMARY
|
295 |
+
if current_ext.remote != ext_remote:
|
296 |
+
style_remote = STYLE_PRIMARY
|
297 |
+
if current_ext.branch != ext_branch:
|
298 |
+
style_branch = STYLE_PRIMARY
|
299 |
+
if current_ext.commit_hash != ext_commit_hash:
|
300 |
+
style_commit = STYLE_PRIMARY
|
301 |
+
|
302 |
+
code += f"""
|
303 |
+
<tr>
|
304 |
+
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
|
305 |
+
<td><label{style_remote}>{remote}</label></td>
|
306 |
+
<td><label{style_branch}>{ext_branch}</label></td>
|
307 |
+
<td><label{style_commit}>{commit_link}</label></td>
|
308 |
+
<td><label{style_commit}>{date_link}</label></td>
|
309 |
+
</tr>
|
310 |
+
"""
|
311 |
+
|
312 |
+
code += """
|
313 |
+
</tbody>
|
314 |
+
</table>
|
315 |
+
"""
|
316 |
+
|
317 |
+
return code
|
318 |
+
|
319 |
+
|
320 |
+
def normalize_git_url(url):
|
321 |
+
if url is None:
|
322 |
+
return ""
|
323 |
+
|
324 |
+
url = url.replace(".git", "")
|
325 |
+
return url
|
326 |
+
|
327 |
+
|
328 |
+
def install_extension_from_url(dirname, url, branch_name=None):
|
329 |
+
check_access()
|
330 |
+
|
331 |
+
if isinstance(dirname, str):
|
332 |
+
dirname = dirname.strip()
|
333 |
+
if isinstance(url, str):
|
334 |
+
url = url.strip()
|
335 |
+
|
336 |
+
assert url, 'No URL specified'
|
337 |
+
|
338 |
+
if dirname is None or dirname == "":
|
339 |
+
*parts, last_part = url.split('/')
|
340 |
+
last_part = normalize_git_url(last_part)
|
341 |
+
|
342 |
+
dirname = last_part
|
343 |
+
|
344 |
+
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
345 |
+
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
346 |
+
|
347 |
+
normalized_url = normalize_git_url(url)
|
348 |
+
if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):
|
349 |
+
raise Exception(f'Extension with this URL is already installed: {url}')
|
350 |
+
|
351 |
+
tmpdir = os.path.join(paths.data_path, "tmp", dirname)
|
352 |
+
|
353 |
+
try:
|
354 |
+
shutil.rmtree(tmpdir, True)
|
355 |
+
if not branch_name:
|
356 |
+
# if no branch is specified, use the default branch
|
357 |
+
with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
|
358 |
+
repo.remote().fetch()
|
359 |
+
for submodule in repo.submodules:
|
360 |
+
submodule.update()
|
361 |
+
else:
|
362 |
+
with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
|
363 |
+
repo.remote().fetch()
|
364 |
+
for submodule in repo.submodules:
|
365 |
+
submodule.update()
|
366 |
+
try:
|
367 |
+
os.rename(tmpdir, target_dir)
|
368 |
+
except OSError as err:
|
369 |
+
if err.errno == errno.EXDEV:
|
370 |
+
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
|
371 |
+
# Since we can't use a rename, do the slower but more versitile shutil.move()
|
372 |
+
shutil.move(tmpdir, target_dir)
|
373 |
+
else:
|
374 |
+
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
|
375 |
+
raise err
|
376 |
+
|
377 |
+
import launch
|
378 |
+
launch.run_extension_installer(target_dir)
|
379 |
+
|
380 |
+
extensions.list_extensions()
|
381 |
+
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
|
382 |
+
finally:
|
383 |
+
shutil.rmtree(tmpdir, True)
|
384 |
+
|
385 |
+
|
386 |
+
def install_extension_from_index(url, hide_tags, sort_column, filter_text):
|
387 |
+
ext_table, message = install_extension_from_url(None, url)
|
388 |
+
|
389 |
+
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
390 |
+
|
391 |
+
return code, ext_table, message, ''
|
392 |
+
|
393 |
+
|
394 |
+
def refresh_available_extensions(url, hide_tags, sort_column):
|
395 |
+
global available_extensions
|
396 |
+
|
397 |
+
import urllib.request
|
398 |
+
with urllib.request.urlopen(url) as response:
|
399 |
+
text = response.read()
|
400 |
+
|
401 |
+
available_extensions = json.loads(text)
|
402 |
+
|
403 |
+
code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
|
404 |
+
|
405 |
+
return url, code, gr.CheckboxGroup.update(choices=tags), '', ''
|
406 |
+
|
407 |
+
|
408 |
+
def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text):
|
409 |
+
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
410 |
+
|
411 |
+
return code, ''
|
412 |
+
|
413 |
+
|
414 |
+
def search_extensions(filter_text, hide_tags, sort_column):
|
415 |
+
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
416 |
+
|
417 |
+
return code, ''
|
418 |
+
|
419 |
+
|
420 |
+
sort_ordering = [
|
421 |
+
# (reverse, order_by_function)
|
422 |
+
(True, lambda x: x.get('added', 'z')),
|
423 |
+
(False, lambda x: x.get('added', 'z')),
|
424 |
+
(False, lambda x: x.get('name', 'z')),
|
425 |
+
(True, lambda x: x.get('name', 'z')),
|
426 |
+
(False, lambda x: 'z'),
|
427 |
+
(True, lambda x: x.get('commit_time', '')),
|
428 |
+
(True, lambda x: x.get('created_at', '')),
|
429 |
+
(True, lambda x: x.get('stars', 0)),
|
430 |
+
]
|
431 |
+
|
432 |
+
|
433 |
+
def get_date(info: dict, key):
|
434 |
+
try:
|
435 |
+
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
|
436 |
+
except (ValueError, TypeError):
|
437 |
+
return ''
|
438 |
+
|
439 |
+
|
440 |
+
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
441 |
+
extlist = available_extensions["extensions"]
|
442 |
+
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
443 |
+
|
444 |
+
tags = available_extensions.get("tags", {})
|
445 |
+
tags_to_hide = set(hide_tags)
|
446 |
+
hidden = 0
|
447 |
+
|
448 |
+
code = f"""<!-- {time.time()} -->
|
449 |
+
<table id="available_extensions">
|
450 |
+
<thead>
|
451 |
+
<tr>
|
452 |
+
<th>Extension</th>
|
453 |
+
<th>Description</th>
|
454 |
+
<th>Action</th>
|
455 |
+
</tr>
|
456 |
+
</thead>
|
457 |
+
<tbody>
|
458 |
+
"""
|
459 |
+
|
460 |
+
sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
|
461 |
+
|
462 |
+
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
|
463 |
+
name = ext.get("name", "noname")
|
464 |
+
stars = int(ext.get("stars", 0))
|
465 |
+
added = ext.get('added', 'unknown')
|
466 |
+
update_time = get_date(ext, 'commit_time')
|
467 |
+
create_time = get_date(ext, 'created_at')
|
468 |
+
url = ext.get("url", None)
|
469 |
+
description = ext.get("description", "")
|
470 |
+
extension_tags = ext.get("tags", [])
|
471 |
+
|
472 |
+
if url is None:
|
473 |
+
continue
|
474 |
+
|
475 |
+
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
476 |
+
extension_tags = extension_tags + ["installed"] if existing else extension_tags
|
477 |
+
|
478 |
+
if any(x for x in extension_tags if x in tags_to_hide):
|
479 |
+
hidden += 1
|
480 |
+
continue
|
481 |
+
|
482 |
+
if filter_text and filter_text.strip():
|
483 |
+
if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():
|
484 |
+
hidden += 1
|
485 |
+
continue
|
486 |
+
|
487 |
+
install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
|
488 |
+
|
489 |
+
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
490 |
+
|
491 |
+
code += f"""
|
492 |
+
<tr>
|
493 |
+
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
|
494 |
+
<td>{html.escape(description)}<p class="info">
|
495 |
+
<span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
|
496 |
+
<td>{install_code}</td>
|
497 |
+
</tr>
|
498 |
+
|
499 |
+
"""
|
500 |
+
|
501 |
+
for tag in [x for x in extension_tags if x not in tags]:
|
502 |
+
tags[tag] = tag
|
503 |
+
|
504 |
+
code += """
|
505 |
+
</tbody>
|
506 |
+
</table>
|
507 |
+
"""
|
508 |
+
|
509 |
+
if hidden > 0:
|
510 |
+
code += f"<p>Extension hidden: {hidden}</p>"
|
511 |
+
|
512 |
+
return code, list(tags)
|
513 |
+
|
514 |
+
|
515 |
+
def preload_extensions_git_metadata():
|
516 |
+
for extension in extensions.extensions:
|
517 |
+
extension.read_info_from_repo()
|
518 |
+
|
519 |
+
|
520 |
+
def create_ui():
|
521 |
+
import modules.ui
|
522 |
+
|
523 |
+
config_states.list_config_states()
|
524 |
+
|
525 |
+
threading.Thread(target=preload_extensions_git_metadata).start()
|
526 |
+
|
527 |
+
with gr.Blocks(analytics_enabled=False) as ui:
|
528 |
+
with gr.Tabs(elem_id="tabs_extensions"):
|
529 |
+
with gr.TabItem("Installed", id="installed"):
|
530 |
+
|
531 |
+
with gr.Row(elem_id="extensions_installed_top"):
|
532 |
+
apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
|
533 |
+
apply = gr.Button(value=apply_label, variant="primary")
|
534 |
+
check = gr.Button(value="Check for updates")
|
535 |
+
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
|
536 |
+
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
|
537 |
+
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
|
538 |
+
|
539 |
+
html = ""
|
540 |
+
if shared.opts.disable_all_extensions != "none":
|
541 |
+
html = """
|
542 |
+
<span style="color: var(--primary-400);">
|
543 |
+
"Disable all extensions" was set, change it to "none" to load all extensions again
|
544 |
+
</span>
|
545 |
+
"""
|
546 |
+
info = gr.HTML(html)
|
547 |
+
extensions_table = gr.HTML('Loading...')
|
548 |
+
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
|
549 |
+
|
550 |
+
apply.click(
|
551 |
+
fn=apply_and_restart,
|
552 |
+
_js="extensions_apply",
|
553 |
+
inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],
|
554 |
+
outputs=[],
|
555 |
+
)
|
556 |
+
|
557 |
+
check.click(
|
558 |
+
fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
|
559 |
+
_js="extensions_check",
|
560 |
+
inputs=[info, extensions_disabled_list],
|
561 |
+
outputs=[extensions_table, info],
|
562 |
+
)
|
563 |
+
|
564 |
+
with gr.TabItem("Available", id="available"):
|
565 |
+
with gr.Row():
|
566 |
+
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
567 |
+
extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
|
568 |
+
available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
|
569 |
+
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
570 |
+
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
571 |
+
|
572 |
+
with gr.Row():
|
573 |
+
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
574 |
+
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
|
575 |
+
|
576 |
+
with gr.Row():
|
577 |
+
search_extensions_text = gr.Text(label="Search").style(container=False)
|
578 |
+
|
579 |
+
install_result = gr.HTML()
|
580 |
+
available_extensions_table = gr.HTML()
|
581 |
+
|
582 |
+
refresh_available_extensions_button.click(
|
583 |
+
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
|
584 |
+
inputs=[available_extensions_index, hide_tags, sort_column],
|
585 |
+
outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
|
586 |
+
)
|
587 |
+
|
588 |
+
install_extension_button.click(
|
589 |
+
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
590 |
+
inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text],
|
591 |
+
outputs=[available_extensions_table, extensions_table, install_result],
|
592 |
+
)
|
593 |
+
|
594 |
+
search_extensions_text.change(
|
595 |
+
fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
|
596 |
+
inputs=[search_extensions_text, hide_tags, sort_column],
|
597 |
+
outputs=[available_extensions_table, install_result],
|
598 |
+
)
|
599 |
+
|
600 |
+
hide_tags.change(
|
601 |
+
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
602 |
+
inputs=[hide_tags, sort_column, search_extensions_text],
|
603 |
+
outputs=[available_extensions_table, install_result]
|
604 |
+
)
|
605 |
+
|
606 |
+
sort_column.change(
|
607 |
+
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
608 |
+
inputs=[hide_tags, sort_column, search_extensions_text],
|
609 |
+
outputs=[available_extensions_table, install_result]
|
610 |
+
)
|
611 |
+
|
612 |
+
with gr.TabItem("Install from URL", id="install_from_url"):
|
613 |
+
install_url = gr.Text(label="URL for extension's git repository")
|
614 |
+
install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
|
615 |
+
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
616 |
+
install_button = gr.Button(value="Install", variant="primary")
|
617 |
+
install_result = gr.HTML(elem_id="extension_install_result")
|
618 |
+
|
619 |
+
install_button.click(
|
620 |
+
fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
|
621 |
+
inputs=[install_dirname, install_url, install_branch],
|
622 |
+
outputs=[install_url, extensions_table, install_result],
|
623 |
+
)
|
624 |
+
|
625 |
+
with gr.TabItem("Backup/Restore"):
|
626 |
+
with gr.Row(elem_id="extensions_backup_top_row"):
|
627 |
+
config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
|
628 |
+
modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
|
629 |
+
config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
|
630 |
+
config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
|
631 |
+
with gr.Row(elem_id="extensions_backup_top_row2"):
|
632 |
+
config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
|
633 |
+
config_save_button = gr.Button(value="Save Current Config")
|
634 |
+
|
635 |
+
config_states_info = gr.HTML("")
|
636 |
+
config_states_table = gr.HTML("Loading...")
|
637 |
+
ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
|
638 |
+
|
639 |
+
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
|
640 |
+
|
641 |
+
dummy_component = gr.Label(visible=False)
|
642 |
+
config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
|
643 |
+
|
644 |
+
config_states_list.change(
|
645 |
+
fn=update_config_states_table,
|
646 |
+
inputs=[config_states_list],
|
647 |
+
outputs=[config_states_table],
|
648 |
+
)
|
649 |
+
|
650 |
+
|
651 |
+
return ui
|
modules/ui_extra_networks.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import urllib.parse
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from modules import shared, ui_extra_networks_user_metadata, errors
|
6 |
+
from modules.images import read_info_from_image, save_image_with_geninfo
|
7 |
+
from modules.ui import up_down_symbol
|
8 |
+
import gradio as gr
|
9 |
+
import json
|
10 |
+
import html
|
11 |
+
|
12 |
+
from modules.ui_components import ToolButton
|
13 |
+
from fastapi.exceptions import HTTPException
|
14 |
+
|
15 |
+
from modules.generation_parameters_copypaste import image_from_url_text
|
16 |
+
from modules.ui_components import ToolButton
|
17 |
+
|
18 |
+
extra_pages = []
|
19 |
+
allowed_dirs = set()
|
20 |
+
refresh_symbol = '\U0001f504' # 🔄
|
21 |
+
#clear_symbol = '\U0001F5D9' # 🗙
|
22 |
+
|
23 |
+
def register_page(page):
|
24 |
+
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
25 |
+
|
26 |
+
extra_pages.append(page)
|
27 |
+
allowed_dirs.clear()
|
28 |
+
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
29 |
+
|
30 |
+
|
31 |
+
def fetch_file(filename: str = ""):
|
32 |
+
from starlette.responses import FileResponse
|
33 |
+
|
34 |
+
if not os.path.isfile(filename):
|
35 |
+
raise HTTPException(status_code=404, detail="File not found")
|
36 |
+
|
37 |
+
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
|
38 |
+
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
39 |
+
|
40 |
+
ext = os.path.splitext(filename)[1].lower()
|
41 |
+
if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
|
42 |
+
raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
|
43 |
+
|
44 |
+
# would profit from returning 304
|
45 |
+
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
46 |
+
|
47 |
+
|
48 |
+
def get_metadata(page: str = "", item: str = ""):
|
49 |
+
from starlette.responses import JSONResponse
|
50 |
+
|
51 |
+
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
52 |
+
if page is None:
|
53 |
+
return JSONResponse({})
|
54 |
+
|
55 |
+
metadata = page.metadata.get(item)
|
56 |
+
if metadata is None:
|
57 |
+
return JSONResponse({})
|
58 |
+
|
59 |
+
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
|
60 |
+
|
61 |
+
|
62 |
+
def get_single_card(page: str = "", tabname: str = "", name: str = ""):
|
63 |
+
from starlette.responses import JSONResponse
|
64 |
+
|
65 |
+
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
66 |
+
|
67 |
+
try:
|
68 |
+
item = page.create_item(name, enable_filter=False)
|
69 |
+
page.items[name] = item
|
70 |
+
except Exception as e:
|
71 |
+
errors.display(e, "creating item for extra network")
|
72 |
+
item = page.items.get(name)
|
73 |
+
|
74 |
+
page.read_user_metadata(item)
|
75 |
+
item_html = page.create_html_for_item(item, tabname)
|
76 |
+
|
77 |
+
return JSONResponse({"html": item_html})
|
78 |
+
|
79 |
+
|
80 |
+
def add_pages_to_demo(app):
|
81 |
+
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
82 |
+
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
83 |
+
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
|
84 |
+
|
85 |
+
|
86 |
+
def quote_js(s):
|
87 |
+
s = s.replace('\\', '\\\\')
|
88 |
+
s = s.replace('"', '\\"')
|
89 |
+
return f'"{s}"'
|
90 |
+
|
91 |
+
|
92 |
+
class ExtraNetworksPage:
|
93 |
+
def __init__(self, title):
|
94 |
+
self.title = title
|
95 |
+
self.name = title.lower()
|
96 |
+
self.id_page = self.name.replace(" ", "_")
|
97 |
+
self.card_page = shared.html("extra-networks-card.html")
|
98 |
+
self.allow_negative_prompt = False
|
99 |
+
self.metadata = {}
|
100 |
+
self.items = {}
|
101 |
+
|
102 |
+
def refresh(self):
|
103 |
+
pass
|
104 |
+
|
105 |
+
def read_user_metadata(self, item):
|
106 |
+
filename = item.get("filename", None)
|
107 |
+
basename, ext = os.path.splitext(filename)
|
108 |
+
metadata_filename = basename + '.json'
|
109 |
+
|
110 |
+
metadata = {}
|
111 |
+
try:
|
112 |
+
if os.path.isfile(metadata_filename):
|
113 |
+
with open(metadata_filename, "r", encoding="utf8") as file:
|
114 |
+
metadata = json.load(file)
|
115 |
+
except Exception as e:
|
116 |
+
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
117 |
+
|
118 |
+
desc = metadata.get("description", None)
|
119 |
+
if desc is not None:
|
120 |
+
item["description"] = desc
|
121 |
+
|
122 |
+
item["user_metadata"] = metadata
|
123 |
+
|
124 |
+
def link_preview(self, filename):
|
125 |
+
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
|
126 |
+
mtime = os.path.getmtime(filename)
|
127 |
+
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
|
128 |
+
|
129 |
+
def search_terms_from_path(self, filename, possible_directories=None):
|
130 |
+
abspath = os.path.abspath(filename)
|
131 |
+
|
132 |
+
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
133 |
+
parentdir = os.path.abspath(parentdir)
|
134 |
+
if abspath.startswith(parentdir):
|
135 |
+
return abspath[len(parentdir):].replace('\\', '/')
|
136 |
+
|
137 |
+
return ""
|
138 |
+
|
139 |
+
|
140 |
+
def create_html(self, tabname):
|
141 |
+
view = "cards" #shared.opts.extra_networks_default_view
|
142 |
+
items_html = ''
|
143 |
+
|
144 |
+
self.metadata = {}
|
145 |
+
|
146 |
+
subdirs = {}
|
147 |
+
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
148 |
+
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
|
149 |
+
for dirname in sorted(dirs, key=shared.natural_sort_key):
|
150 |
+
x = os.path.join(root, dirname)
|
151 |
+
|
152 |
+
if not os.path.isdir(x):
|
153 |
+
continue
|
154 |
+
|
155 |
+
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
156 |
+
while subdir.startswith("/"):
|
157 |
+
subdir = subdir[1:]
|
158 |
+
|
159 |
+
is_empty = len(os.listdir(x)) == 0
|
160 |
+
if not is_empty and not subdir.endswith("/"):
|
161 |
+
subdir = subdir + "/"
|
162 |
+
|
163 |
+
if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
|
164 |
+
continue
|
165 |
+
|
166 |
+
subdirs[subdir] = 1
|
167 |
+
|
168 |
+
if subdirs:
|
169 |
+
subdirs = {"": 1, **subdirs}
|
170 |
+
|
171 |
+
|
172 |
+
#<option value='{html.escape(subdir if subdir!="" else "all")}'>{html.escape(subdir if subdir!="" else "all")}</option>
|
173 |
+
subdirs_html = "".join([f"""
|
174 |
+
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
175 |
+
{html.escape(subdir if subdir!="" else "all")}
|
176 |
+
</button>
|
177 |
+
""" for subdir in subdirs])
|
178 |
+
|
179 |
+
self.items = {x["name"]: x for x in self.list_items()}
|
180 |
+
for item in self.items.values():
|
181 |
+
metadata = item.get("metadata")
|
182 |
+
if metadata:
|
183 |
+
self.metadata[item["name"]] = metadata
|
184 |
+
|
185 |
+
if "user_metadata" not in item:
|
186 |
+
self.read_user_metadata(item)
|
187 |
+
|
188 |
+
items_html += self.create_html_for_item(item, tabname)
|
189 |
+
|
190 |
+
if items_html == '':
|
191 |
+
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
192 |
+
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
193 |
+
|
194 |
+
self_name_id = self.name.replace(" ", "_")
|
195 |
+
|
196 |
+
# <select onchange='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
197 |
+
# {subdirs_html}
|
198 |
+
# </select>
|
199 |
+
res = f"""
|
200 |
+
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
|
201 |
+
{subdirs_html}
|
202 |
+
</div>
|
203 |
+
<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
|
204 |
+
{items_html}
|
205 |
+
</div>
|
206 |
+
"""
|
207 |
+
|
208 |
+
return res
|
209 |
+
|
210 |
+
def create_item(self, name, index=None):
|
211 |
+
raise NotImplementedError()
|
212 |
+
|
213 |
+
def list_items(self):
|
214 |
+
raise NotImplementedError()
|
215 |
+
|
216 |
+
def allowed_directories_for_previews(self):
|
217 |
+
return []
|
218 |
+
|
219 |
+
def create_html_for_item(self, item, tabname):
|
220 |
+
"""
|
221 |
+
Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
|
222 |
+
"""
|
223 |
+
|
224 |
+
preview = item.get("preview", None)
|
225 |
+
|
226 |
+
onclick = item.get("onclick", None)
|
227 |
+
if onclick is None:
|
228 |
+
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
229 |
+
|
230 |
+
#height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
231 |
+
#width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
232 |
+
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
|
233 |
+
|
234 |
+
metadata_button = ""
|
235 |
+
metadata = item.get("metadata")
|
236 |
+
if metadata:
|
237 |
+
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
|
238 |
+
|
239 |
+
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
|
240 |
+
|
241 |
+
local_path = ""
|
242 |
+
filename = item.get("filename", "")
|
243 |
+
|
244 |
+
for reldir in self.allowed_directories_for_previews():
|
245 |
+
absdir = os.path.abspath(reldir)
|
246 |
+
if filename.startswith(absdir):
|
247 |
+
local_path = filename[len(absdir):]
|
248 |
+
|
249 |
+
# if this is true, the item must not be shown in the default view, and must instead only be
|
250 |
+
# shown when searching for it
|
251 |
+
|
252 |
+
if shared.opts.extra_networks_hidden_models == "Always":
|
253 |
+
search_only = False
|
254 |
+
else:
|
255 |
+
search_only = "/." in local_path or "\\." in local_path
|
256 |
+
|
257 |
+
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
258 |
+
return ""
|
259 |
+
|
260 |
+
sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
|
261 |
+
|
262 |
+
args = {
|
263 |
+
#"background_image": background_image,
|
264 |
+
#"style": f"'display: none; {height}{width}'",
|
265 |
+
"preview_image": html.escape(preview) if preview else './file=html/card-no-preview.png',
|
266 |
+
"prompt": item.get("prompt", None),
|
267 |
+
"tabname": quote_js(tabname),
|
268 |
+
"local_preview": quote_js(item["local_preview"]),
|
269 |
+
"name": item["name"],
|
270 |
+
"description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
|
271 |
+
"card_clicked": onclick,
|
272 |
+
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
|
273 |
+
"search_term": item.get("search_term", ""),
|
274 |
+
"metadata_button": metadata_button,
|
275 |
+
"edit_button": edit_button,
|
276 |
+
"search_only": " search_only" if search_only else "",
|
277 |
+
"sort_keys": sort_keys,
|
278 |
+
}
|
279 |
+
|
280 |
+
return self.card_page.format(**args)
|
281 |
+
|
282 |
+
def get_sort_keys(self, path):
|
283 |
+
"""
|
284 |
+
List of default keys used for sorting in the UI.
|
285 |
+
"""
|
286 |
+
pth = Path(path)
|
287 |
+
stat = pth.stat()
|
288 |
+
return {
|
289 |
+
"date_created": int(stat.st_ctime or 0),
|
290 |
+
"date_modified": int(stat.st_mtime or 0),
|
291 |
+
"name": pth.name.lower(),
|
292 |
+
}
|
293 |
+
|
294 |
+
def find_preview(self, path):
|
295 |
+
"""
|
296 |
+
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
297 |
+
"""
|
298 |
+
|
299 |
+
preview_extensions = ["png", "jpg", "jpeg", "webp"]
|
300 |
+
if shared.opts.samples_format not in preview_extensions:
|
301 |
+
preview_extensions.append(shared.opts.samples_format)
|
302 |
+
|
303 |
+
# file_name = os.path.basename(path)
|
304 |
+
# location = os.path.dirname(path)
|
305 |
+
# preview_path = location + "/preview/" + file_name
|
306 |
+
# potential_files = sum([[path + "." + ext, path + ".preview." + ext, preview_path + "." + ext, preview_path + ".preview." + ext] for ext in preview_extensions], [])
|
307 |
+
|
308 |
+
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
|
309 |
+
|
310 |
+
for file in potential_files:
|
311 |
+
if os.path.isfile(file):
|
312 |
+
return self.link_preview(file)
|
313 |
+
|
314 |
+
for file in potential_files:
|
315 |
+
if os.path.isfile(file):
|
316 |
+
return self.link_preview(file)
|
317 |
+
|
318 |
+
return None
|
319 |
+
|
320 |
+
def find_description(self, path):
|
321 |
+
"""
|
322 |
+
Find and read a description file for a given path (without extension).
|
323 |
+
"""
|
324 |
+
for file in [f"{path}.txt", f"{path}.description.txt"]:
|
325 |
+
try:
|
326 |
+
with open(file, "r", encoding="utf-8", errors="replace") as f:
|
327 |
+
return f.read()
|
328 |
+
except OSError:
|
329 |
+
pass
|
330 |
+
return None
|
331 |
+
|
332 |
+
def create_user_metadata_editor(self, ui, tabname):
|
333 |
+
return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
|
334 |
+
|
335 |
+
|
336 |
+
def initialize():
|
337 |
+
extra_pages.clear()
|
338 |
+
|
339 |
+
|
340 |
+
def register_default_pages():
|
341 |
+
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
|
342 |
+
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
|
343 |
+
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
|
344 |
+
register_page(ExtraNetworksPageTextualInversion())
|
345 |
+
register_page(ExtraNetworksPageHypernetworks())
|
346 |
+
register_page(ExtraNetworksPageCheckpoints())
|
347 |
+
|
348 |
+
|
349 |
+
class ExtraNetworksUi:
|
350 |
+
def __init__(self):
|
351 |
+
self.pages = None
|
352 |
+
"""gradio HTML components related to extra networks' pages"""
|
353 |
+
|
354 |
+
self.page_contents = None
|
355 |
+
"""HTML content of the above; empty initially, filled when extra pages have to be shown"""
|
356 |
+
|
357 |
+
self.stored_extra_pages = None
|
358 |
+
|
359 |
+
self.button_save_preview = None
|
360 |
+
self.preview_target_filename = None
|
361 |
+
|
362 |
+
self.tabname = None
|
363 |
+
|
364 |
+
|
365 |
+
def pages_in_preferred_order(pages):
|
366 |
+
tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
|
367 |
+
|
368 |
+
def tab_name_score(name):
|
369 |
+
name = name.lower()
|
370 |
+
for i, possible_match in enumerate(tab_order):
|
371 |
+
if possible_match in name:
|
372 |
+
return i
|
373 |
+
|
374 |
+
return len(pages)
|
375 |
+
|
376 |
+
tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
|
377 |
+
|
378 |
+
return sorted(pages, key=lambda x: tab_scores[x.name])
|
379 |
+
|
380 |
+
|
381 |
+
def create_ui(container, button, tabname):
|
382 |
+
ui = ExtraNetworksUi()
|
383 |
+
ui.pages = []
|
384 |
+
ui.pages_contents = []
|
385 |
+
ui.user_metadata_editors = []
|
386 |
+
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
387 |
+
ui.tabname = tabname
|
388 |
+
|
389 |
+
with gr.Accordion("Extra Networks", open=True):
|
390 |
+
|
391 |
+
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
392 |
+
for page in ui.stored_extra_pages:
|
393 |
+
page_id = page.title.lower().replace(" ", "_")
|
394 |
+
with gr.Tab(page.title, id=page_id):
|
395 |
+
#elem_id = f"{tabname}_{page_id}_cards_html"
|
396 |
+
#page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
397 |
+
page_elem = gr.HTML(page.create_html(ui.tabname))
|
398 |
+
ui.pages.append(page_elem)
|
399 |
+
|
400 |
+
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
|
401 |
+
|
402 |
+
editor = page.create_user_metadata_editor(ui, tabname)
|
403 |
+
editor.create_ui()
|
404 |
+
ui.user_metadata_editors.append(editor)
|
405 |
+
|
406 |
+
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
407 |
+
#gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
|
408 |
+
#gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
|
409 |
+
button_refresh = ToolButton(value=refresh_symbol, elem_id=tabname+"_extra_refresh")
|
410 |
+
|
411 |
+
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
412 |
+
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
413 |
+
|
414 |
+
|
415 |
+
|
416 |
+
def toggle_visibility(is_visible):
|
417 |
+
is_visible = not is_visible
|
418 |
+
|
419 |
+
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
|
420 |
+
|
421 |
+
def fill_tabs(is_empty):
|
422 |
+
"""Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
|
423 |
+
|
424 |
+
if not ui.pages_contents:
|
425 |
+
refresh()
|
426 |
+
|
427 |
+
if is_empty:
|
428 |
+
return True, *ui.pages_contents
|
429 |
+
|
430 |
+
return True, *[gr.update() for _ in ui.pages_contents]
|
431 |
+
|
432 |
+
state_visible = gr.State(value=False)
|
433 |
+
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
|
434 |
+
|
435 |
+
state_empty = gr.State(value=True)
|
436 |
+
button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
|
437 |
+
|
438 |
+
def refresh():
|
439 |
+
for pg in ui.stored_extra_pages:
|
440 |
+
pg.refresh()
|
441 |
+
|
442 |
+
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
|
443 |
+
|
444 |
+
return ui.pages_contents
|
445 |
+
|
446 |
+
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
447 |
+
|
448 |
+
return ui
|
449 |
+
|
450 |
+
|
451 |
+
def path_is_parent(parent_path, child_path):
|
452 |
+
parent_path = os.path.abspath(parent_path)
|
453 |
+
child_path = os.path.abspath(child_path)
|
454 |
+
|
455 |
+
return child_path.startswith(parent_path)
|
456 |
+
|
457 |
+
|
458 |
+
def setup_ui(ui, gallery):
|
459 |
+
def save_preview(index, images, filename):
|
460 |
+
# this function is here for backwards compatibility and likely will be removed soon
|
461 |
+
|
462 |
+
if len(images) == 0:
|
463 |
+
print("There is no image in gallery to save as a preview.")
|
464 |
+
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
465 |
+
|
466 |
+
index = int(index)
|
467 |
+
index = 0 if index < 0 else index
|
468 |
+
index = len(images) - 1 if index >= len(images) else index
|
469 |
+
|
470 |
+
img_info = images[index if index >= 0 else 0]
|
471 |
+
image = image_from_url_text(img_info)
|
472 |
+
geninfo, items = read_info_from_image(image)
|
473 |
+
|
474 |
+
is_allowed = False
|
475 |
+
for extra_page in ui.stored_extra_pages:
|
476 |
+
if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
|
477 |
+
is_allowed = True
|
478 |
+
break
|
479 |
+
|
480 |
+
assert is_allowed, f'writing to {filename} is not allowed'
|
481 |
+
|
482 |
+
save_image_with_geninfo(image, geninfo, filename)
|
483 |
+
|
484 |
+
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
485 |
+
|
486 |
+
ui.button_save_preview.click(
|
487 |
+
fn=save_preview,
|
488 |
+
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
489 |
+
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
490 |
+
outputs=[*ui.pages]
|
491 |
+
)
|
492 |
+
|
493 |
+
for editor in ui.user_metadata_editors:
|
494 |
+
editor.setup_ui(gallery)
|
495 |
+
|
496 |
+
|
modules/ui_extra_networks_checkpoints.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import os
|
3 |
+
|
4 |
+
from modules import shared, ui_extra_networks, sd_models
|
5 |
+
from modules.ui_extra_networks import quote_js
|
6 |
+
|
7 |
+
|
8 |
+
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__('Checkpoints')
|
11 |
+
|
12 |
+
def refresh(self):
|
13 |
+
shared.refresh_checkpoints()
|
14 |
+
|
15 |
+
def create_item(self, name, index=None):
|
16 |
+
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
17 |
+
path, ext = os.path.splitext(checkpoint.filename)
|
18 |
+
return {
|
19 |
+
"name": checkpoint.name_for_extra,
|
20 |
+
"filename": checkpoint.filename,
|
21 |
+
"preview": self.find_preview(path),
|
22 |
+
"description": self.find_description(path),
|
23 |
+
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
24 |
+
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
|
25 |
+
"local_preview": f"{path}.{shared.opts.samples_format}",
|
26 |
+
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
27 |
+
}
|
28 |
+
|
29 |
+
def list_items(self):
|
30 |
+
for index, name in enumerate(sd_models.checkpoints_list):
|
31 |
+
yield self.create_item(name, index)
|
32 |
+
|
33 |
+
def allowed_directories_for_previews(self):
|
34 |
+
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
35 |
+
|
modules/ui_extra_networks_hypernets.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from modules import shared, ui_extra_networks
|
4 |
+
from modules.ui_extra_networks import quote_js
|
5 |
+
|
6 |
+
|
7 |
+
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__('Hypernetworks')
|
10 |
+
|
11 |
+
def refresh(self):
|
12 |
+
shared.reload_hypernetworks()
|
13 |
+
|
14 |
+
def create_item(self, name, index=None):
|
15 |
+
full_path = shared.hypernetworks[name]
|
16 |
+
path, ext = os.path.splitext(full_path)
|
17 |
+
|
18 |
+
return {
|
19 |
+
"name": name,
|
20 |
+
"filename": full_path,
|
21 |
+
"preview": self.find_preview(path),
|
22 |
+
"description": self.find_description(path),
|
23 |
+
"search_term": self.search_terms_from_path(path),
|
24 |
+
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
25 |
+
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
26 |
+
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
27 |
+
}
|
28 |
+
|
29 |
+
def list_items(self):
|
30 |
+
for index, name in enumerate(shared.hypernetworks):
|
31 |
+
yield self.create_item(name, index)
|
32 |
+
|
33 |
+
def allowed_directories_for_previews(self):
|
34 |
+
return [shared.cmd_opts.hypernetwork_dir]
|
35 |
+
|
modules/ui_extra_networks_textual_inversion.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from modules import ui_extra_networks, sd_hijack, shared
|
4 |
+
from modules.ui_extra_networks import quote_js
|
5 |
+
|
6 |
+
|
7 |
+
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__('Textual Inversion')
|
10 |
+
self.allow_negative_prompt = True
|
11 |
+
|
12 |
+
def refresh(self):
|
13 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
14 |
+
|
15 |
+
def create_item(self, name, index=None):
|
16 |
+
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
17 |
+
|
18 |
+
path, ext = os.path.splitext(embedding.filename)
|
19 |
+
return {
|
20 |
+
"name": name,
|
21 |
+
"filename": embedding.filename,
|
22 |
+
"preview": self.find_preview(path),
|
23 |
+
"description": self.find_description(path),
|
24 |
+
"search_term": self.search_terms_from_path(embedding.filename),
|
25 |
+
"prompt": quote_js(embedding.name),
|
26 |
+
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
27 |
+
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
28 |
+
}
|
29 |
+
|
30 |
+
def list_items(self):
|
31 |
+
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
|
32 |
+
yield self.create_item(name, index)
|
33 |
+
|
34 |
+
def allowed_directories_for_previews(self):
|
35 |
+
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
modules/ui_extra_networks_user_metadata.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import html
|
3 |
+
import json
|
4 |
+
import os.path
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
from modules import generation_parameters_copypaste, images, sysinfo, errors
|
9 |
+
|
10 |
+
|
11 |
+
class UserMetadataEditor:
|
12 |
+
|
13 |
+
def __init__(self, ui, tabname, page):
|
14 |
+
self.ui = ui
|
15 |
+
self.tabname = tabname
|
16 |
+
self.page = page
|
17 |
+
self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
|
18 |
+
|
19 |
+
self.box = None
|
20 |
+
|
21 |
+
self.edit_name_input = None
|
22 |
+
self.button_edit = None
|
23 |
+
|
24 |
+
self.edit_name = None
|
25 |
+
self.edit_description = None
|
26 |
+
self.edit_notes = None
|
27 |
+
self.html_filedata = None
|
28 |
+
self.html_preview = None
|
29 |
+
self.html_status = None
|
30 |
+
|
31 |
+
self.button_cancel = None
|
32 |
+
self.button_replace_preview = None
|
33 |
+
self.button_save = None
|
34 |
+
|
35 |
+
def get_user_metadata(self, name):
|
36 |
+
item = self.page.items.get(name, {})
|
37 |
+
|
38 |
+
user_metadata = item.get('user_metadata', None)
|
39 |
+
if user_metadata is None:
|
40 |
+
user_metadata = {}
|
41 |
+
item['user_metadata'] = user_metadata
|
42 |
+
|
43 |
+
return user_metadata
|
44 |
+
|
45 |
+
def create_extra_default_items_in_left_column(self):
|
46 |
+
pass
|
47 |
+
|
48 |
+
def create_default_editor_elems(self):
|
49 |
+
with gr.Row():
|
50 |
+
with gr.Column(scale=2):
|
51 |
+
self.edit_name = gr.HTML(elem_classes="extra-network-name")
|
52 |
+
self.edit_description = gr.Textbox(label="Description", lines=4)
|
53 |
+
self.html_filedata = gr.HTML()
|
54 |
+
|
55 |
+
self.create_extra_default_items_in_left_column()
|
56 |
+
|
57 |
+
with gr.Column(scale=1, min_width=0):
|
58 |
+
self.html_preview = gr.HTML()
|
59 |
+
|
60 |
+
def create_default_buttons(self):
|
61 |
+
|
62 |
+
with gr.Row(elem_classes="edit-user-metadata-buttons"):
|
63 |
+
self.button_cancel = gr.Button('Cancel')
|
64 |
+
self.button_replace_preview = gr.Button('Replace preview', variant='primary')
|
65 |
+
self.button_save = gr.Button('Save', variant='primary')
|
66 |
+
|
67 |
+
self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")
|
68 |
+
|
69 |
+
self.button_cancel.click(fn=None, _js="closePopup")
|
70 |
+
|
71 |
+
def get_card_html(self, name):
|
72 |
+
item = self.page.items.get(name, {})
|
73 |
+
|
74 |
+
preview_url = item.get("preview", None)
|
75 |
+
|
76 |
+
if not preview_url:
|
77 |
+
filename, _ = os.path.splitext(item["filename"])
|
78 |
+
preview_url = self.page.find_preview(filename)
|
79 |
+
item["preview"] = preview_url
|
80 |
+
|
81 |
+
if preview_url:
|
82 |
+
preview = f'''
|
83 |
+
<div class='card standalone-card-preview'>
|
84 |
+
<img src="{html.escape(preview_url)}" class="preview">
|
85 |
+
</div>
|
86 |
+
'''
|
87 |
+
else:
|
88 |
+
preview = "<div class='card standalone-card-preview'></div>"
|
89 |
+
|
90 |
+
return preview
|
91 |
+
|
92 |
+
def get_metadata_table(self, name):
|
93 |
+
item = self.page.items.get(name, {})
|
94 |
+
try:
|
95 |
+
filename = item["filename"]
|
96 |
+
|
97 |
+
stats = os.stat(filename)
|
98 |
+
params = [
|
99 |
+
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
100 |
+
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
101 |
+
]
|
102 |
+
|
103 |
+
return params
|
104 |
+
except Exception as e:
|
105 |
+
errors.display(e, f"reading info for {name}")
|
106 |
+
return []
|
107 |
+
|
108 |
+
def put_values_into_components(self, name):
|
109 |
+
user_metadata = self.get_user_metadata(name)
|
110 |
+
|
111 |
+
try:
|
112 |
+
params = self.get_metadata_table(name)
|
113 |
+
except Exception as e:
|
114 |
+
errors.display(e, f"reading metadata info for {name}")
|
115 |
+
params = []
|
116 |
+
|
117 |
+
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
|
118 |
+
|
119 |
+
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
|
120 |
+
|
121 |
+
def write_user_metadata(self, name, metadata):
|
122 |
+
item = self.page.items.get(name, {})
|
123 |
+
filename = item.get("filename", None)
|
124 |
+
basename, ext = os.path.splitext(filename)
|
125 |
+
|
126 |
+
with open(basename + '.json', "w", encoding="utf8") as file:
|
127 |
+
json.dump(metadata, file)
|
128 |
+
|
129 |
+
def save_user_metadata(self, name, desc, notes):
|
130 |
+
user_metadata = self.get_user_metadata(name)
|
131 |
+
user_metadata["description"] = desc
|
132 |
+
user_metadata["notes"] = notes
|
133 |
+
|
134 |
+
self.write_user_metadata(name, user_metadata)
|
135 |
+
|
136 |
+
def setup_save_handler(self, button, func, components):
|
137 |
+
button\
|
138 |
+
.click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
|
139 |
+
.then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])
|
140 |
+
|
141 |
+
def create_editor(self):
|
142 |
+
self.create_default_editor_elems()
|
143 |
+
|
144 |
+
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
145 |
+
|
146 |
+
self.create_default_buttons()
|
147 |
+
|
148 |
+
self.button_edit\
|
149 |
+
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
|
150 |
+
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
|
151 |
+
|
152 |
+
self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])
|
153 |
+
|
154 |
+
def create_ui(self):
|
155 |
+
with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
|
156 |
+
self.box = box
|
157 |
+
|
158 |
+
self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
|
159 |
+
self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")
|
160 |
+
|
161 |
+
self.create_editor()
|
162 |
+
|
163 |
+
def save_preview(self, index, gallery, name):
|
164 |
+
if len(gallery) == 0:
|
165 |
+
return self.get_card_html(name), "There is no image in gallery to save as a preview."
|
166 |
+
|
167 |
+
item = self.page.items.get(name, {})
|
168 |
+
|
169 |
+
index = int(index)
|
170 |
+
index = 0 if index < 0 else index
|
171 |
+
index = len(gallery) - 1 if index >= len(gallery) else index
|
172 |
+
|
173 |
+
img_info = gallery[index if index >= 0 else 0]
|
174 |
+
image = generation_parameters_copypaste.image_from_url_text(img_info)
|
175 |
+
geninfo, items = images.read_info_from_image(image)
|
176 |
+
|
177 |
+
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
|
178 |
+
|
179 |
+
return self.get_card_html(name), ''
|
180 |
+
|
181 |
+
def setup_ui(self, gallery):
|
182 |
+
self.button_replace_preview.click(
|
183 |
+
fn=self.save_preview,
|
184 |
+
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
185 |
+
inputs=[self.edit_name_input, gallery, self.edit_name_input],
|
186 |
+
outputs=[self.html_preview, self.html_status]
|
187 |
+
).then(
|
188 |
+
fn=None,
|
189 |
+
_js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
|
190 |
+
inputs=[self.edit_name_input],
|
191 |
+
outputs=[]
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
|
modules/ui_gradio_extensions.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from modules import localization, shared, scripts
|
5 |
+
from modules.paths import script_path, data_path
|
6 |
+
|
7 |
+
|
8 |
+
def webpath(fn):
|
9 |
+
if fn.startswith(script_path):
|
10 |
+
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
|
11 |
+
else:
|
12 |
+
web_path = os.path.abspath(fn)
|
13 |
+
|
14 |
+
return f'file={web_path}?{os.path.getmtime(fn)}'
|
15 |
+
|
16 |
+
|
17 |
+
def javascript_html():
|
18 |
+
# Ensure localization is in `window` before scripts
|
19 |
+
head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
|
20 |
+
|
21 |
+
script_js = os.path.join(script_path, "script.js")
|
22 |
+
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
23 |
+
|
24 |
+
for script in scripts.list_scripts("javascript", ".js"):
|
25 |
+
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
|
26 |
+
|
27 |
+
for script in scripts.list_scripts("javascript", ".mjs"):
|
28 |
+
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
|
29 |
+
|
30 |
+
if shared.cmd_opts.theme:
|
31 |
+
head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
|
32 |
+
|
33 |
+
return head
|
34 |
+
|
35 |
+
|
36 |
+
def css_html():
|
37 |
+
head = ""
|
38 |
+
|
39 |
+
def stylesheet(fn):
|
40 |
+
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
|
41 |
+
|
42 |
+
for cssfile in scripts.list_files_with_name("style.css"):
|
43 |
+
if not os.path.isfile(cssfile):
|
44 |
+
continue
|
45 |
+
|
46 |
+
head += stylesheet(cssfile)
|
47 |
+
|
48 |
+
if os.path.exists(os.path.join(data_path, "user.css")):
|
49 |
+
head += stylesheet(os.path.join(data_path, "user.css"))
|
50 |
+
|
51 |
+
return head
|
52 |
+
|
53 |
+
|
54 |
+
def reload_javascript():
|
55 |
+
js = javascript_html()
|
56 |
+
css = css_html()
|
57 |
+
|
58 |
+
def template_response(*args, **kwargs):
|
59 |
+
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
|
60 |
+
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
|
61 |
+
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
|
62 |
+
res.init_headers()
|
63 |
+
return res
|
64 |
+
|
65 |
+
gr.routes.templates.TemplateResponse = template_response
|
66 |
+
|
67 |
+
|
68 |
+
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
|
69 |
+
shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
|
modules/ui_loadsave.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from modules import errors
|
7 |
+
from modules.ui_components import ToolButton
|
8 |
+
|
9 |
+
|
10 |
+
class UiLoadsave:
|
11 |
+
"""allows saving and restorig default values for gradio components"""
|
12 |
+
|
13 |
+
def __init__(self, filename):
|
14 |
+
self.filename = filename
|
15 |
+
self.ui_settings = {}
|
16 |
+
self.component_mapping = {}
|
17 |
+
self.error_loading = False
|
18 |
+
self.finalized_ui = False
|
19 |
+
|
20 |
+
self.ui_defaults_view = None
|
21 |
+
self.ui_defaults_apply = None
|
22 |
+
self.ui_defaults_review = None
|
23 |
+
|
24 |
+
try:
|
25 |
+
if os.path.exists(self.filename):
|
26 |
+
self.ui_settings = self.read_from_file()
|
27 |
+
except Exception as e:
|
28 |
+
self.error_loading = True
|
29 |
+
errors.display(e, "loading settings")
|
30 |
+
|
31 |
+
def add_component(self, path, x):
|
32 |
+
"""adds component to the registry of tracked components"""
|
33 |
+
|
34 |
+
assert not self.finalized_ui
|
35 |
+
|
36 |
+
def apply_field(obj, field, condition=None, init_field=None):
|
37 |
+
key = f"{path}/{field}"
|
38 |
+
|
39 |
+
if getattr(obj, 'custom_script_source', None) is not None:
|
40 |
+
key = f"customscript/{obj.custom_script_source}/{key}"
|
41 |
+
|
42 |
+
if getattr(obj, 'do_not_save_to_config', False):
|
43 |
+
return
|
44 |
+
|
45 |
+
saved_value = self.ui_settings.get(key, None)
|
46 |
+
if saved_value is None:
|
47 |
+
self.ui_settings[key] = getattr(obj, field)
|
48 |
+
elif condition and not condition(saved_value):
|
49 |
+
pass
|
50 |
+
else:
|
51 |
+
setattr(obj, field, saved_value)
|
52 |
+
if init_field is not None:
|
53 |
+
init_field(saved_value)
|
54 |
+
|
55 |
+
if field == 'value' and key not in self.component_mapping:
|
56 |
+
self.component_mapping[key] = x
|
57 |
+
|
58 |
+
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
|
59 |
+
apply_field(x, 'visible')
|
60 |
+
|
61 |
+
if type(x) == gr.Slider:
|
62 |
+
apply_field(x, 'value')
|
63 |
+
apply_field(x, 'minimum')
|
64 |
+
apply_field(x, 'maximum')
|
65 |
+
apply_field(x, 'step')
|
66 |
+
|
67 |
+
if type(x) == gr.Radio:
|
68 |
+
apply_field(x, 'value', lambda val: val in x.choices)
|
69 |
+
|
70 |
+
if type(x) == gr.Checkbox:
|
71 |
+
apply_field(x, 'value')
|
72 |
+
|
73 |
+
if type(x) == gr.Textbox:
|
74 |
+
apply_field(x, 'value')
|
75 |
+
|
76 |
+
if type(x) == gr.Number:
|
77 |
+
apply_field(x, 'value')
|
78 |
+
|
79 |
+
if type(x) == gr.Dropdown:
|
80 |
+
def check_dropdown(val):
|
81 |
+
if getattr(x, 'multiselect', False):
|
82 |
+
return all(value in x.choices for value in val)
|
83 |
+
else:
|
84 |
+
return val in x.choices
|
85 |
+
|
86 |
+
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
87 |
+
|
88 |
+
def check_tab_id(tab_id):
|
89 |
+
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
|
90 |
+
if type(tab_id) == str:
|
91 |
+
tab_ids = [t.id for t in tab_items]
|
92 |
+
return tab_id in tab_ids
|
93 |
+
elif type(tab_id) == int:
|
94 |
+
return 0 <= tab_id < len(tab_items)
|
95 |
+
else:
|
96 |
+
return False
|
97 |
+
|
98 |
+
if type(x) == gr.Tabs:
|
99 |
+
apply_field(x, 'selected', check_tab_id)
|
100 |
+
|
101 |
+
def add_block(self, x, path=""):
|
102 |
+
"""adds all components inside a gradio block x to the registry of tracked components"""
|
103 |
+
|
104 |
+
if hasattr(x, 'children'):
|
105 |
+
if isinstance(x, gr.Tabs) and x.elem_id is not None:
|
106 |
+
# Tabs element can't have a label, have to use elem_id instead
|
107 |
+
self.add_component(f"{path}/Tabs@{x.elem_id}", x)
|
108 |
+
for c in x.children:
|
109 |
+
self.add_block(c, path)
|
110 |
+
elif x.label is not None:
|
111 |
+
self.add_component(f"{path}/{x.label}", x)
|
112 |
+
elif isinstance(x, gr.Button) and x.value is not None:
|
113 |
+
self.add_component(f"{path}/{x.value}", x)
|
114 |
+
|
115 |
+
def read_from_file(self):
|
116 |
+
with open(self.filename, "r", encoding="utf8") as file:
|
117 |
+
return json.load(file)
|
118 |
+
|
119 |
+
def write_to_file(self, current_ui_settings):
|
120 |
+
with open(self.filename, "w", encoding="utf8") as file:
|
121 |
+
json.dump(current_ui_settings, file, indent=4)
|
122 |
+
|
123 |
+
def dump_defaults(self):
|
124 |
+
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
125 |
+
|
126 |
+
if self.error_loading and os.path.exists(self.filename):
|
127 |
+
return
|
128 |
+
|
129 |
+
self.write_to_file(self.ui_settings)
|
130 |
+
|
131 |
+
def iter_changes(self, current_ui_settings, values):
|
132 |
+
"""
|
133 |
+
given a dictionary with defaults from a file and current values from gradio elements, returns
|
134 |
+
an iterator over tuples of values that are not the same between the file and the current;
|
135 |
+
tuple contents are: path, old value, new value
|
136 |
+
"""
|
137 |
+
|
138 |
+
for (path, component), new_value in zip(self.component_mapping.items(), values):
|
139 |
+
old_value = current_ui_settings.get(path)
|
140 |
+
|
141 |
+
choices = getattr(component, 'choices', None)
|
142 |
+
if isinstance(new_value, int) and choices:
|
143 |
+
if new_value >= len(choices):
|
144 |
+
continue
|
145 |
+
|
146 |
+
new_value = choices[new_value]
|
147 |
+
|
148 |
+
if new_value == old_value:
|
149 |
+
continue
|
150 |
+
|
151 |
+
if old_value is None and new_value == '' or new_value == []:
|
152 |
+
continue
|
153 |
+
|
154 |
+
yield path, old_value, new_value
|
155 |
+
|
156 |
+
def ui_view(self, *values):
|
157 |
+
text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]
|
158 |
+
|
159 |
+
for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
|
160 |
+
if old_value is None:
|
161 |
+
old_value = "<span class='ui-defaults-none'>None</span>"
|
162 |
+
|
163 |
+
text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")
|
164 |
+
|
165 |
+
if len(text) == 1:
|
166 |
+
text.append("<tr><td colspan=3>No changes</td></tr>")
|
167 |
+
|
168 |
+
text.append("</tbody>")
|
169 |
+
return "".join(text)
|
170 |
+
|
171 |
+
def ui_apply(self, *values):
|
172 |
+
num_changed = 0
|
173 |
+
|
174 |
+
current_ui_settings = self.read_from_file()
|
175 |
+
|
176 |
+
for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
|
177 |
+
num_changed += 1
|
178 |
+
current_ui_settings[path] = new_value
|
179 |
+
|
180 |
+
if num_changed == 0:
|
181 |
+
return "No changes."
|
182 |
+
|
183 |
+
self.write_to_file(current_ui_settings)
|
184 |
+
|
185 |
+
return f"Wrote {num_changed} changes."
|
186 |
+
|
187 |
+
def create_ui(self):
|
188 |
+
"""creates ui elements for editing defaults UI, without adding any logic to them"""
|
189 |
+
|
190 |
+
gr.HTML(
|
191 |
+
f"This page allows you to change default values in UI elements on other tabs.<br />"
|
192 |
+
f"Make your changes, press 'View changes' to review the changed default values,<br />"
|
193 |
+
f"then press 'Apply' to write them to {self.filename}.<br />"
|
194 |
+
f"New defaults will apply after you restart the UI.<br />"
|
195 |
+
)
|
196 |
+
|
197 |
+
with gr.Row():
|
198 |
+
self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
|
199 |
+
self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
|
200 |
+
|
201 |
+
self.ui_defaults_review = gr.HTML("")
|
202 |
+
|
203 |
+
def setup_ui(self):
|
204 |
+
"""adds logic to elements created with create_ui; all add_block class must be made before this"""
|
205 |
+
|
206 |
+
assert not self.finalized_ui
|
207 |
+
self.finalized_ui = True
|
208 |
+
|
209 |
+
self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
|
210 |
+
self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
|
modules/ui_postprocessing.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from modules import scripts, shared, ui_common, postprocessing, call_queue
|
3 |
+
import modules.generation_parameters_copypaste as parameters_copypaste
|
4 |
+
|
5 |
+
def create_ui():
|
6 |
+
tab_index = gr.State(value=0)
|
7 |
+
gr.Row(elem_id="extras_2img_prompt_image", visible=False)
|
8 |
+
with gr.Row():
|
9 |
+
with gr.Column(elem_id="extras_2img_results"):
|
10 |
+
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras_2img", shared.opts.outdir_extras_samples)
|
11 |
+
gr.Row(elem_id="extras_2img_splitter")
|
12 |
+
with gr.Column(variant='panel', elem_id="extras_2img_settings"):
|
13 |
+
submit = gr.Button('Upscale', elem_id="extras_generate", variant='primary')
|
14 |
+
with gr.Column(elem_id="extras_2img_settings_scroll"):
|
15 |
+
with gr.Accordion("Image Source", elem_id="extras_accordion", open=True):
|
16 |
+
with gr.Tabs(elem_id="mode_extras"):
|
17 |
+
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
18 |
+
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
19 |
+
|
20 |
+
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
21 |
+
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
22 |
+
|
23 |
+
with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
24 |
+
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
25 |
+
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
26 |
+
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
27 |
+
|
28 |
+
script_inputs = scripts.scripts_postproc.setup_ui()
|
29 |
+
|
30 |
+
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
|
31 |
+
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
|
32 |
+
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
33 |
+
|
34 |
+
submit.click(
|
35 |
+
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
36 |
+
inputs=[
|
37 |
+
tab_index,
|
38 |
+
extras_image,
|
39 |
+
image_batch,
|
40 |
+
extras_batch_input_dir,
|
41 |
+
extras_batch_output_dir,
|
42 |
+
show_extras_results,
|
43 |
+
*script_inputs
|
44 |
+
],
|
45 |
+
outputs=[
|
46 |
+
result_images,
|
47 |
+
html_info_x,
|
48 |
+
html_info,
|
49 |
+
]
|
50 |
+
)
|
51 |
+
|
52 |
+
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
53 |
+
|
54 |
+
extras_image.change(
|
55 |
+
fn=scripts.scripts_postproc.image_changed,
|
56 |
+
inputs=[], outputs=[]
|
57 |
+
)
|
modules/ui_settings.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
|
4 |
+
from modules.call_queue import wrap_gradio_call
|
5 |
+
from modules.shared import opts
|
6 |
+
from modules.ui_components import FormRow
|
7 |
+
from modules.ui_gradio_extensions import reload_javascript
|
8 |
+
|
9 |
+
|
10 |
+
def get_value_for_setting(key):
|
11 |
+
value = getattr(opts, key)
|
12 |
+
|
13 |
+
info = opts.data_labels[key]
|
14 |
+
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
15 |
+
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
16 |
+
|
17 |
+
return gr.update(value=value, **args)
|
18 |
+
|
19 |
+
|
20 |
+
def create_setting_component(key, is_quicksettings=False):
|
21 |
+
def fun():
|
22 |
+
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
23 |
+
|
24 |
+
info = opts.data_labels[key]
|
25 |
+
t = type(info.default)
|
26 |
+
|
27 |
+
args = info.component_args() if callable(info.component_args) else info.component_args
|
28 |
+
|
29 |
+
if info.component is not None:
|
30 |
+
comp = info.component
|
31 |
+
elif t == str:
|
32 |
+
comp = gr.Textbox
|
33 |
+
elif t == int:
|
34 |
+
comp = gr.Number
|
35 |
+
elif t == bool:
|
36 |
+
comp = gr.Checkbox
|
37 |
+
else:
|
38 |
+
raise Exception(f'bad options item type: {t} for key {key}')
|
39 |
+
|
40 |
+
elem_id = f"setting_{key}"
|
41 |
+
|
42 |
+
if info.refresh is not None:
|
43 |
+
if is_quicksettings:
|
44 |
+
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
45 |
+
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
46 |
+
else:
|
47 |
+
with FormRow():
|
48 |
+
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
49 |
+
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
50 |
+
else:
|
51 |
+
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
52 |
+
|
53 |
+
return res
|
54 |
+
|
55 |
+
|
56 |
+
class UiSettings:
|
57 |
+
submit = None
|
58 |
+
result = None
|
59 |
+
interface = None
|
60 |
+
components = None
|
61 |
+
component_dict = None
|
62 |
+
dummy_component = None
|
63 |
+
quicksettings_list = None
|
64 |
+
quicksettings_names = None
|
65 |
+
text_settings = None
|
66 |
+
|
67 |
+
def run_settings(self, *args):
|
68 |
+
changed = []
|
69 |
+
|
70 |
+
for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
|
71 |
+
assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
72 |
+
|
73 |
+
for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
|
74 |
+
if comp == self.dummy_component:
|
75 |
+
continue
|
76 |
+
|
77 |
+
if opts.set(key, value):
|
78 |
+
changed.append(key)
|
79 |
+
|
80 |
+
try:
|
81 |
+
opts.save(shared.config_filename)
|
82 |
+
except RuntimeError:
|
83 |
+
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
|
84 |
+
return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.'
|
85 |
+
|
86 |
+
def run_settings_single(self, value, key):
|
87 |
+
if not opts.same_type(value, opts.data_labels[key].default):
|
88 |
+
return gr.update(visible=True), opts.dumpjson()
|
89 |
+
|
90 |
+
if not opts.set(key, value):
|
91 |
+
return gr.update(value=getattr(opts, key)), opts.dumpjson()
|
92 |
+
|
93 |
+
opts.save(shared.config_filename)
|
94 |
+
|
95 |
+
return get_value_for_setting(key), opts.dumpjson()
|
96 |
+
|
97 |
+
def create_ui(self, loadsave, dummy_component):
|
98 |
+
self.components = []
|
99 |
+
self.component_dict = {}
|
100 |
+
self.dummy_component = dummy_component
|
101 |
+
|
102 |
+
shared.settings_components = self.component_dict
|
103 |
+
|
104 |
+
script_callbacks.ui_settings_callback()
|
105 |
+
opts.reorder()
|
106 |
+
|
107 |
+
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
108 |
+
with gr.Row():
|
109 |
+
with gr.Column(scale=6):
|
110 |
+
self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
|
111 |
+
with gr.Column():
|
112 |
+
restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
|
113 |
+
|
114 |
+
self.result = gr.HTML(elem_id="settings_result")
|
115 |
+
|
116 |
+
self.quicksettings_names = opts.quicksettings_list
|
117 |
+
self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
|
118 |
+
|
119 |
+
self.quicksettings_list = []
|
120 |
+
|
121 |
+
previous_section = None
|
122 |
+
current_tab = None
|
123 |
+
current_row = None
|
124 |
+
with gr.Tabs(elem_id="settings"):
|
125 |
+
for i, (k, item) in enumerate(opts.data_labels.items()):
|
126 |
+
section_must_be_skipped = item.section[0] is None
|
127 |
+
|
128 |
+
if previous_section != item.section and not section_must_be_skipped:
|
129 |
+
elem_id, text = item.section
|
130 |
+
|
131 |
+
if current_tab is not None:
|
132 |
+
current_row.__exit__()
|
133 |
+
current_tab.__exit__()
|
134 |
+
|
135 |
+
gr.Group()
|
136 |
+
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
137 |
+
current_tab.__enter__()
|
138 |
+
current_row = gr.Column(variant='compact')
|
139 |
+
current_row.__enter__()
|
140 |
+
|
141 |
+
previous_section = item.section
|
142 |
+
|
143 |
+
if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
|
144 |
+
self.quicksettings_list.append((i, k, item))
|
145 |
+
self.components.append(dummy_component)
|
146 |
+
elif section_must_be_skipped:
|
147 |
+
self.components.append(dummy_component)
|
148 |
+
else:
|
149 |
+
component = create_setting_component(k)
|
150 |
+
self.component_dict[k] = component
|
151 |
+
self.components.append(component)
|
152 |
+
|
153 |
+
if current_tab is not None:
|
154 |
+
current_row.__exit__()
|
155 |
+
current_tab.__exit__()
|
156 |
+
|
157 |
+
with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
|
158 |
+
loadsave.create_ui()
|
159 |
+
|
160 |
+
with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
|
161 |
+
gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
with gr.Column(scale=1):
|
165 |
+
sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
|
166 |
+
with gr.Column(scale=1):
|
167 |
+
sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
|
168 |
+
with gr.Column(scale=100):
|
169 |
+
pass
|
170 |
+
|
171 |
+
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
|
172 |
+
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
173 |
+
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
174 |
+
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
175 |
+
with gr.Row():
|
176 |
+
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
177 |
+
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
178 |
+
|
179 |
+
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
180 |
+
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
181 |
+
|
182 |
+
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
183 |
+
|
184 |
+
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
185 |
+
|
186 |
+
unload_sd_model.click(
|
187 |
+
fn=sd_models.unload_model_weights,
|
188 |
+
inputs=[],
|
189 |
+
outputs=[]
|
190 |
+
)
|
191 |
+
|
192 |
+
reload_sd_model.click(
|
193 |
+
fn=sd_models.reload_model_weights,
|
194 |
+
inputs=[],
|
195 |
+
outputs=[]
|
196 |
+
)
|
197 |
+
|
198 |
+
request_notifications.click(
|
199 |
+
fn=lambda: None,
|
200 |
+
inputs=[],
|
201 |
+
outputs=[],
|
202 |
+
_js='function(){}'
|
203 |
+
)
|
204 |
+
|
205 |
+
download_localization.click(
|
206 |
+
fn=lambda: None,
|
207 |
+
inputs=[],
|
208 |
+
outputs=[],
|
209 |
+
_js='download_localization'
|
210 |
+
)
|
211 |
+
|
212 |
+
def reload_scripts():
|
213 |
+
scripts.reload_script_body_only()
|
214 |
+
reload_javascript() # need to refresh the html page
|
215 |
+
|
216 |
+
reload_script_bodies.click(
|
217 |
+
fn=reload_scripts,
|
218 |
+
inputs=[],
|
219 |
+
outputs=[]
|
220 |
+
)
|
221 |
+
|
222 |
+
restart_gradio.click(
|
223 |
+
fn=shared.state.request_restart,
|
224 |
+
_js='restart_reload',
|
225 |
+
inputs=[],
|
226 |
+
outputs=[],
|
227 |
+
)
|
228 |
+
|
229 |
+
def check_file(x):
|
230 |
+
if x is None:
|
231 |
+
return ''
|
232 |
+
|
233 |
+
if sysinfo.check(x.decode('utf8', errors='ignore')):
|
234 |
+
return 'Valid'
|
235 |
+
|
236 |
+
return 'Invalid'
|
237 |
+
|
238 |
+
sysinfo_check_file.change(
|
239 |
+
fn=check_file,
|
240 |
+
inputs=[sysinfo_check_file],
|
241 |
+
outputs=[sysinfo_check_output],
|
242 |
+
)
|
243 |
+
|
244 |
+
self.interface = settings_interface
|
245 |
+
|
246 |
+
def add_quicksettings(self):
|
247 |
+
with gr.Row(elem_id="quicksettings", variant="compact"):
|
248 |
+
for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
|
249 |
+
component = create_setting_component(k, is_quicksettings=True)
|
250 |
+
self.component_dict[k] = component
|
251 |
+
|
252 |
+
def add_functionality(self, demo):
|
253 |
+
self.submit.click(
|
254 |
+
fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
|
255 |
+
inputs=self.components,
|
256 |
+
outputs=[self.text_settings, self.result],
|
257 |
+
)
|
258 |
+
|
259 |
+
for _i, k, _item in self.quicksettings_list:
|
260 |
+
component = self.component_dict[k]
|
261 |
+
info = opts.data_labels[k]
|
262 |
+
|
263 |
+
if isinstance(component, gr.Textbox):
|
264 |
+
methods = [component.submit, component.blur]
|
265 |
+
elif hasattr(component, 'release'):
|
266 |
+
methods = [component.release]
|
267 |
+
else:
|
268 |
+
methods = [component.change]
|
269 |
+
|
270 |
+
for method in methods:
|
271 |
+
method(
|
272 |
+
fn=lambda value, k=k: self.run_settings_single(value, key=k),
|
273 |
+
inputs=[component],
|
274 |
+
outputs=[component, self.text_settings],
|
275 |
+
show_progress=info.refresh is not None,
|
276 |
+
)
|
277 |
+
|
278 |
+
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
279 |
+
button_set_checkpoint.click(
|
280 |
+
fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
|
281 |
+
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
282 |
+
inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
|
283 |
+
outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
|
284 |
+
)
|
285 |
+
|
286 |
+
component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
|
287 |
+
|
288 |
+
def get_settings_values():
|
289 |
+
return [get_value_for_setting(key) for key in component_keys]
|
290 |
+
|
291 |
+
demo.load(
|
292 |
+
fn=get_settings_values,
|
293 |
+
inputs=[],
|
294 |
+
outputs=[self.component_dict[k] for k in component_keys],
|
295 |
+
queue=False,
|
296 |
+
)
|
modules/ui_tempdir.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
from collections import namedtuple
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import gradio.components
|
7 |
+
|
8 |
+
from PIL import PngImagePlugin
|
9 |
+
|
10 |
+
from modules import shared
|
11 |
+
|
12 |
+
|
13 |
+
Savedfile = namedtuple("Savedfile", ["name"])
|
14 |
+
|
15 |
+
|
16 |
+
def register_tmp_file(gradio, filename):
|
17 |
+
if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
|
18 |
+
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
|
19 |
+
|
20 |
+
if hasattr(gradio, 'temp_dirs'): # gradio 3.9
|
21 |
+
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
|
22 |
+
|
23 |
+
|
24 |
+
def check_tmp_file(gradio, filename):
|
25 |
+
if hasattr(gradio, 'temp_file_sets'):
|
26 |
+
return any(filename in fileset for fileset in gradio.temp_file_sets)
|
27 |
+
|
28 |
+
if hasattr(gradio, 'temp_dirs'):
|
29 |
+
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
|
30 |
+
|
31 |
+
return False
|
32 |
+
|
33 |
+
|
34 |
+
def save_pil_to_file(self, pil_image, dir=None, format="png"):
|
35 |
+
already_saved_as = getattr(pil_image, 'already_saved_as', None)
|
36 |
+
if already_saved_as and os.path.isfile(already_saved_as):
|
37 |
+
register_tmp_file(shared.demo, already_saved_as)
|
38 |
+
filename = already_saved_as
|
39 |
+
|
40 |
+
if not shared.opts.save_images_add_number:
|
41 |
+
filename += f'?{os.path.getmtime(already_saved_as)}'
|
42 |
+
|
43 |
+
return filename
|
44 |
+
|
45 |
+
if shared.opts.temp_dir != "":
|
46 |
+
dir = shared.opts.temp_dir
|
47 |
+
|
48 |
+
use_metadata = False
|
49 |
+
metadata = PngImagePlugin.PngInfo()
|
50 |
+
for key, value in pil_image.info.items():
|
51 |
+
if isinstance(key, str) and isinstance(value, str):
|
52 |
+
metadata.add_text(key, value)
|
53 |
+
use_metadata = True
|
54 |
+
|
55 |
+
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
56 |
+
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
57 |
+
return file_obj.name
|
58 |
+
|
59 |
+
|
60 |
+
# override save to file function so that it also writes PNG info
|
61 |
+
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
62 |
+
|
63 |
+
|
64 |
+
def on_tmpdir_changed():
|
65 |
+
if shared.opts.temp_dir == "" or shared.demo is None:
|
66 |
+
return
|
67 |
+
|
68 |
+
os.makedirs(shared.opts.temp_dir, exist_ok=True)
|
69 |
+
|
70 |
+
register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
|
71 |
+
|
72 |
+
|
73 |
+
def cleanup_tmpdr():
|
74 |
+
temp_dir = shared.opts.temp_dir
|
75 |
+
if temp_dir == "" or not os.path.isdir(temp_dir):
|
76 |
+
return
|
77 |
+
|
78 |
+
for root, _, files in os.walk(temp_dir, topdown=False):
|
79 |
+
for name in files:
|
80 |
+
_, extension = os.path.splitext(name)
|
81 |
+
if extension != ".png":
|
82 |
+
continue
|
83 |
+
|
84 |
+
filename = os.path.join(root, name)
|
85 |
+
os.remove(filename)
|
modules/upscaler.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from abc import abstractmethod
|
3 |
+
|
4 |
+
import PIL
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import modules.shared
|
8 |
+
from modules import modelloader, shared
|
9 |
+
|
10 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
11 |
+
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
|
12 |
+
|
13 |
+
|
14 |
+
class Upscaler:
|
15 |
+
name = None
|
16 |
+
model_path = None
|
17 |
+
model_name = None
|
18 |
+
model_url = None
|
19 |
+
enable = True
|
20 |
+
filter = None
|
21 |
+
model = None
|
22 |
+
user_path = None
|
23 |
+
scalers: []
|
24 |
+
tile = True
|
25 |
+
|
26 |
+
def __init__(self, create_dirs=False):
|
27 |
+
self.mod_pad_h = None
|
28 |
+
self.tile_size = modules.shared.opts.ESRGAN_tile
|
29 |
+
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
30 |
+
self.device = modules.shared.device
|
31 |
+
self.img = None
|
32 |
+
self.output = None
|
33 |
+
self.scale = 1
|
34 |
+
self.half = not modules.shared.cmd_opts.no_half
|
35 |
+
self.pre_pad = 0
|
36 |
+
self.mod_scale = None
|
37 |
+
self.model_download_path = None
|
38 |
+
|
39 |
+
if self.model_path is None and self.name:
|
40 |
+
self.model_path = os.path.join(shared.models_path, self.name)
|
41 |
+
if self.model_path and create_dirs:
|
42 |
+
os.makedirs(self.model_path, exist_ok=True)
|
43 |
+
|
44 |
+
try:
|
45 |
+
import cv2 # noqa: F401
|
46 |
+
self.can_tile = True
|
47 |
+
except Exception:
|
48 |
+
pass
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def do_upscale(self, img: PIL.Image, selected_model: str):
|
52 |
+
return img
|
53 |
+
|
54 |
+
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
|
55 |
+
self.scale = scale
|
56 |
+
dest_w = int((img.width * scale) // 8 * 8)
|
57 |
+
dest_h = int((img.height * scale) // 8 * 8)
|
58 |
+
|
59 |
+
for _ in range(3):
|
60 |
+
shape = (img.width, img.height)
|
61 |
+
|
62 |
+
img = self.do_upscale(img, selected_model)
|
63 |
+
|
64 |
+
if shape == (img.width, img.height):
|
65 |
+
break
|
66 |
+
|
67 |
+
if img.width >= dest_w and img.height >= dest_h:
|
68 |
+
break
|
69 |
+
|
70 |
+
if img.width != dest_w or img.height != dest_h:
|
71 |
+
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
72 |
+
|
73 |
+
return img
|
74 |
+
|
75 |
+
@abstractmethod
|
76 |
+
def load_model(self, path: str):
|
77 |
+
pass
|
78 |
+
|
79 |
+
def find_models(self, ext_filter=None) -> list:
|
80 |
+
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
|
81 |
+
|
82 |
+
def update_status(self, prompt):
|
83 |
+
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
84 |
+
|
85 |
+
|
86 |
+
class UpscalerData:
|
87 |
+
name = None
|
88 |
+
data_path = None
|
89 |
+
scale: int = 4
|
90 |
+
scaler: Upscaler = None
|
91 |
+
model: None
|
92 |
+
|
93 |
+
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
94 |
+
self.name = name
|
95 |
+
self.data_path = path
|
96 |
+
self.local_data_path = path
|
97 |
+
self.scaler = upscaler
|
98 |
+
self.scale = scale
|
99 |
+
self.model = model
|
100 |
+
|
101 |
+
|
102 |
+
class UpscalerNone(Upscaler):
|
103 |
+
name = "None"
|
104 |
+
scalers = []
|
105 |
+
|
106 |
+
def load_model(self, path):
|
107 |
+
pass
|
108 |
+
|
109 |
+
def do_upscale(self, img, selected_model=None):
|
110 |
+
return img
|
111 |
+
|
112 |
+
def __init__(self, dirname=None):
|
113 |
+
super().__init__(False)
|
114 |
+
self.scalers = [UpscalerData("None", None, self)]
|
115 |
+
|
116 |
+
|
117 |
+
class UpscalerLanczos(Upscaler):
|
118 |
+
scalers = []
|
119 |
+
|
120 |
+
def do_upscale(self, img, selected_model=None):
|
121 |
+
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
122 |
+
|
123 |
+
def load_model(self, _):
|
124 |
+
pass
|
125 |
+
|
126 |
+
def __init__(self, dirname=None):
|
127 |
+
super().__init__(False)
|
128 |
+
self.name = "Lanczos"
|
129 |
+
self.scalers = [UpscalerData("Lanczos", None, self)]
|
130 |
+
|
131 |
+
|
132 |
+
class UpscalerNearest(Upscaler):
|
133 |
+
scalers = []
|
134 |
+
|
135 |
+
def do_upscale(self, img, selected_model=None):
|
136 |
+
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
|
137 |
+
|
138 |
+
def load_model(self, _):
|
139 |
+
pass
|
140 |
+
|
141 |
+
def __init__(self, dirname=None):
|
142 |
+
super().__init__(False)
|
143 |
+
self.name = "Nearest"
|
144 |
+
self.scalers = [UpscalerData("Nearest", None, self)]
|
modules/xlmr.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertPreTrainedModel, BertConfig
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
5 |
+
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
class BertSeriesConfig(BertConfig):
|
9 |
+
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
10 |
+
|
11 |
+
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
12 |
+
self.project_dim = project_dim
|
13 |
+
self.pooler_fn = pooler_fn
|
14 |
+
self.learn_encoder = learn_encoder
|
15 |
+
|
16 |
+
class RobertaSeriesConfig(XLMRobertaConfig):
|
17 |
+
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
18 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
19 |
+
self.project_dim = project_dim
|
20 |
+
self.pooler_fn = pooler_fn
|
21 |
+
self.learn_encoder = learn_encoder
|
22 |
+
|
23 |
+
|
24 |
+
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
25 |
+
|
26 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
27 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
28 |
+
config_class = BertSeriesConfig
|
29 |
+
|
30 |
+
def __init__(self, config=None, **kargs):
|
31 |
+
# modify initialization for autoloading
|
32 |
+
if config is None:
|
33 |
+
config = XLMRobertaConfig()
|
34 |
+
config.attention_probs_dropout_prob= 0.1
|
35 |
+
config.bos_token_id=0
|
36 |
+
config.eos_token_id=2
|
37 |
+
config.hidden_act='gelu'
|
38 |
+
config.hidden_dropout_prob=0.1
|
39 |
+
config.hidden_size=1024
|
40 |
+
config.initializer_range=0.02
|
41 |
+
config.intermediate_size=4096
|
42 |
+
config.layer_norm_eps=1e-05
|
43 |
+
config.max_position_embeddings=514
|
44 |
+
|
45 |
+
config.num_attention_heads=16
|
46 |
+
config.num_hidden_layers=24
|
47 |
+
config.output_past=True
|
48 |
+
config.pad_token_id=1
|
49 |
+
config.position_embedding_type= "absolute"
|
50 |
+
|
51 |
+
config.type_vocab_size= 1
|
52 |
+
config.use_cache=True
|
53 |
+
config.vocab_size= 250002
|
54 |
+
config.project_dim = 768
|
55 |
+
config.learn_encoder = False
|
56 |
+
super().__init__(config)
|
57 |
+
self.roberta = XLMRobertaModel(config)
|
58 |
+
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
59 |
+
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
60 |
+
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
61 |
+
self.pooler = lambda x: x[:,0]
|
62 |
+
self.post_init()
|
63 |
+
|
64 |
+
def encode(self,c):
|
65 |
+
device = next(self.parameters()).device
|
66 |
+
text = self.tokenizer(c,
|
67 |
+
truncation=True,
|
68 |
+
max_length=77,
|
69 |
+
return_length=False,
|
70 |
+
return_overflowing_tokens=False,
|
71 |
+
padding="max_length",
|
72 |
+
return_tensors="pt")
|
73 |
+
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
74 |
+
text["attention_mask"] = torch.tensor(
|
75 |
+
text['attention_mask']).to(device)
|
76 |
+
features = self(**text)
|
77 |
+
return features['projection_state']
|
78 |
+
|
79 |
+
def forward(
|
80 |
+
self,
|
81 |
+
input_ids: Optional[torch.Tensor] = None,
|
82 |
+
attention_mask: Optional[torch.Tensor] = None,
|
83 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
84 |
+
position_ids: Optional[torch.Tensor] = None,
|
85 |
+
head_mask: Optional[torch.Tensor] = None,
|
86 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
87 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
88 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
89 |
+
output_attentions: Optional[bool] = None,
|
90 |
+
return_dict: Optional[bool] = None,
|
91 |
+
output_hidden_states: Optional[bool] = None,
|
92 |
+
) :
|
93 |
+
r"""
|
94 |
+
"""
|
95 |
+
|
96 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
97 |
+
|
98 |
+
|
99 |
+
outputs = self.roberta(
|
100 |
+
input_ids=input_ids,
|
101 |
+
attention_mask=attention_mask,
|
102 |
+
token_type_ids=token_type_ids,
|
103 |
+
position_ids=position_ids,
|
104 |
+
head_mask=head_mask,
|
105 |
+
inputs_embeds=inputs_embeds,
|
106 |
+
encoder_hidden_states=encoder_hidden_states,
|
107 |
+
encoder_attention_mask=encoder_attention_mask,
|
108 |
+
output_attentions=output_attentions,
|
109 |
+
output_hidden_states=True,
|
110 |
+
return_dict=return_dict,
|
111 |
+
)
|
112 |
+
|
113 |
+
# last module outputs
|
114 |
+
sequence_output = outputs[0]
|
115 |
+
|
116 |
+
|
117 |
+
# project every module
|
118 |
+
sequence_output_ln = self.pre_LN(sequence_output)
|
119 |
+
|
120 |
+
# pooler
|
121 |
+
pooler_output = self.pooler(sequence_output_ln)
|
122 |
+
pooler_output = self.transformation(pooler_output)
|
123 |
+
projection_state = self.transformation(outputs.last_hidden_state)
|
124 |
+
|
125 |
+
return {
|
126 |
+
'pooler_output':pooler_output,
|
127 |
+
'last_hidden_state':outputs.last_hidden_state,
|
128 |
+
'hidden_states':outputs.hidden_states,
|
129 |
+
'attentions':outputs.attentions,
|
130 |
+
'projection_state':projection_state,
|
131 |
+
'sequence_out': sequence_output
|
132 |
+
}
|
133 |
+
|
134 |
+
|
135 |
+
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
136 |
+
base_model_prefix = 'roberta'
|
137 |
+
config_class= RobertaSeriesConfig
|
outputs/txt2img-images/2023-07-30/00000-4104476258.png
ADDED
outputs/txt2img-images/2023-07-30/00001-1264812310.png
ADDED
outputs/txt2img-images/2023-07-30/00002-629074369.png
ADDED
outputs/txt2img-images/2023-07-30/00003-3929529382.png
ADDED
outputs/txt2img-images/2023-07-30/00004-2891905160.png
ADDED
outputs/txt2img-images/2023-07-30/00005-1703927525.png
ADDED
outputs/txt2img-images/2023-07-30/00006-1703927525.png
ADDED
outputs/txt2img-images/2023-07-30/00007-1703927525.png
ADDED
outputs/txt2img-images/2023-07-30/00008-1703927525.png
ADDED
outputs/txt2img-images/2023-07-30/00009-1703927525.jpg
ADDED
outputs/txt2img-images/2023-07-30/00010-210755578.jpg
ADDED
outputs/txt2img-images/2023-07-30/00011-3978311133.jpg
ADDED
outputs/txt2img-images/2023-07-30/00012-3786155085.jpg
ADDED
outputs/txt2img-images/2023-07-30/00013-445379948.jpg
ADDED
outputs/txt2img-images/2023-07-30/00014-3277595636.png
ADDED
package.json
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "stable-diffusion-webui",
|
3 |
+
"version": "0.0.0",
|
4 |
+
"devDependencies": {
|
5 |
+
"eslint": "^8.40.0"
|
6 |
+
},
|
7 |
+
"scripts": {
|
8 |
+
"lint": "eslint .",
|
9 |
+
"fix": "eslint --fix ."
|
10 |
+
}
|
11 |
+
}
|
params.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
masterpiece, best quality,
|
2 |
+
Negative prompt: (worst quality, low quality:1.4)
|
3 |
+
Steps: 20, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 3277595636, Size: 512x768, Model hash: ee5e7d0285, Model: SCH_Excelsior, Version: 1.5.0
|
pyproject.toml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.ruff]
|
2 |
+
|
3 |
+
target-version = "py39"
|
4 |
+
|
5 |
+
extend-select = [
|
6 |
+
"B",
|
7 |
+
"C",
|
8 |
+
"I",
|
9 |
+
"W",
|
10 |
+
]
|
11 |
+
|
12 |
+
exclude = [
|
13 |
+
"extensions",
|
14 |
+
"extensions-disabled",
|
15 |
+
]
|
16 |
+
|
17 |
+
ignore = [
|
18 |
+
"E501", # Line too long
|
19 |
+
"E731", # Do not assign a `lambda` expression, use a `def`
|
20 |
+
|
21 |
+
"I001", # Import block is un-sorted or un-formatted
|
22 |
+
"C901", # Function is too complex
|
23 |
+
"C408", # Rewrite as a literal
|
24 |
+
"W605", # invalid escape sequence, messes with some docstrings
|
25 |
+
]
|
26 |
+
|
27 |
+
[tool.ruff.per-file-ignores]
|
28 |
+
"webui.py" = ["E402"] # Module level import not at top of file
|
29 |
+
|
30 |
+
[tool.ruff.flake8-bugbear]
|
31 |
+
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
|
32 |
+
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
|
33 |
+
|
34 |
+
[tool.pytest.ini_options]
|
35 |
+
base_url = "http://127.0.0.1:7860"
|
repositories/BLIP/BLIP.gif
ADDED
Git LFS Details
|
repositories/BLIP/CODEOWNERS
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
|
2 |
+
#ECCN:Open Source
|
repositories/BLIP/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Salesforce Open Source Community Code of Conduct
|
2 |
+
|
3 |
+
## About the Code of Conduct
|
4 |
+
|
5 |
+
Equality is a core value at Salesforce. We believe a diverse and inclusive
|
6 |
+
community fosters innovation and creativity, and are committed to building a
|
7 |
+
culture where everyone feels included.
|
8 |
+
|
9 |
+
Salesforce open-source projects are committed to providing a friendly, safe, and
|
10 |
+
welcoming environment for all, regardless of gender identity and expression,
|
11 |
+
sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
|
12 |
+
race, age, religion, level of experience, education, socioeconomic status, or
|
13 |
+
other similar personal characteristics.
|
14 |
+
|
15 |
+
The goal of this code of conduct is to specify a baseline standard of behavior so
|
16 |
+
that people with different social values and communication styles can work
|
17 |
+
together effectively, productively, and respectfully in our open source community.
|
18 |
+
It also establishes a mechanism for reporting issues and resolving conflicts.
|
19 |
+
|
20 |
+
All questions and reports of abusive, harassing, or otherwise unacceptable behavior
|
21 |
+
in a Salesforce open-source project may be reported by contacting the Salesforce
|
22 |
+
Open Source Conduct Committee at ossconduct@salesforce.com.
|
23 |
+
|
24 |
+
## Our Pledge
|
25 |
+
|
26 |
+
In the interest of fostering an open and welcoming environment, we as
|
27 |
+
contributors and maintainers pledge to making participation in our project and
|
28 |
+
our community a harassment-free experience for everyone, regardless of gender
|
29 |
+
identity and expression, sexual orientation, disability, physical appearance,
|
30 |
+
body size, ethnicity, nationality, race, age, religion, level of experience, education,
|
31 |
+
socioeconomic status, or other similar personal characteristics.
|
32 |
+
|
33 |
+
## Our Standards
|
34 |
+
|
35 |
+
Examples of behavior that contributes to creating a positive environment
|
36 |
+
include:
|
37 |
+
|
38 |
+
* Using welcoming and inclusive language
|
39 |
+
* Being respectful of differing viewpoints and experiences
|
40 |
+
* Gracefully accepting constructive criticism
|
41 |
+
* Focusing on what is best for the community
|
42 |
+
* Showing empathy toward other community members
|
43 |
+
|
44 |
+
Examples of unacceptable behavior by participants include:
|
45 |
+
|
46 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
47 |
+
advances
|
48 |
+
* Personal attacks, insulting/derogatory comments, or trolling
|
49 |
+
* Public or private harassment
|
50 |
+
* Publishing, or threatening to publish, others' private information—such as
|
51 |
+
a physical or electronic address—without explicit permission
|
52 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
53 |
+
professional setting
|
54 |
+
* Advocating for or encouraging any of the above behaviors
|
55 |
+
|
56 |
+
## Our Responsibilities
|
57 |
+
|
58 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
59 |
+
behavior and are expected to take appropriate and fair corrective action in
|
60 |
+
response to any instances of unacceptable behavior.
|
61 |
+
|
62 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
63 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
64 |
+
that are not aligned with this Code of Conduct, or to ban temporarily or
|
65 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
66 |
+
threatening, offensive, or harmful.
|
67 |
+
|
68 |
+
## Scope
|
69 |
+
|
70 |
+
This Code of Conduct applies both within project spaces and in public spaces
|
71 |
+
when an individual is representing the project or its community. Examples of
|
72 |
+
representing a project or community include using an official project email
|
73 |
+
address, posting via an official social media account, or acting as an appointed
|
74 |
+
representative at an online or offline event. Representation of a project may be
|
75 |
+
further defined and clarified by project maintainers.
|
76 |
+
|
77 |
+
## Enforcement
|
78 |
+
|
79 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
80 |
+
reported by contacting the Salesforce Open Source Conduct Committee
|
81 |
+
at ossconduct@salesforce.com. All complaints will be reviewed and investigated
|
82 |
+
and will result in a response that is deemed necessary and appropriate to the
|
83 |
+
circumstances. The committee is obligated to maintain confidentiality with
|
84 |
+
regard to the reporter of an incident. Further details of specific enforcement
|
85 |
+
policies may be posted separately.
|
86 |
+
|
87 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
88 |
+
faith may face temporary or permanent repercussions as determined by other
|
89 |
+
members of the project's leadership and the Salesforce Open Source Conduct
|
90 |
+
Committee.
|
91 |
+
|
92 |
+
## Attribution
|
93 |
+
|
94 |
+
This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
|
95 |
+
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
|
96 |
+
It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
|
97 |
+
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
|
98 |
+
|
99 |
+
This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
|
100 |
+
|
101 |
+
[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
|
102 |
+
[golang-coc]: https://golang.org/conduct
|
103 |
+
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
|
104 |
+
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
|
105 |
+
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
|
repositories/BLIP/LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2022, Salesforce.com, Inc.
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
5 |
+
|
6 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
7 |
+
|
8 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
9 |
+
|
10 |
+
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
11 |
+
|
12 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
repositories/BLIP/README.md
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
|
2 |
+
|
3 |
+
<img src="BLIP.gif" width="700">
|
4 |
+
|
5 |
+
This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
|
6 |
+
To install the dependencies, run <pre/>pip install -r requirements.txt</pre>
|
7 |
+
|
8 |
+
Catalog:
|
9 |
+
- [x] Inference demo
|
10 |
+
- [x] Pre-trained and finetuned checkpoints
|
11 |
+
- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
|
12 |
+
- [x] Pre-training code
|
13 |
+
- [x] Zero-shot video-text retrieval
|
14 |
+
- [x] Download of bootstrapped pre-training datasets
|
15 |
+
|
16 |
+
|
17 |
+
### Inference demo:
|
18 |
+
Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
|
19 |
+
The demo includes code for:
|
20 |
+
1. Image captioning
|
21 |
+
2. Open-ended visual question answering
|
22 |
+
3. Multimodal / unimodal feature extraction
|
23 |
+
4. Image-text matching
|
24 |
+
|
25 |
+
Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
|
26 |
+
|
27 |
+
Replicate web demo and Docker image is also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
|
28 |
+
|
29 |
+
### Pre-trained checkpoints:
|
30 |
+
Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
|
31 |
+
--- | :---: | :---: | :---:
|
32 |
+
14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
|
33 |
+
129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>
|
34 |
+
|
35 |
+
### Finetuned checkpoints:
|
36 |
+
Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
|
37 |
+
--- | :---: | :---: | :---:
|
38 |
+
Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
|
39 |
+
Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
|
40 |
+
Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
|
41 |
+
VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
|
42 |
+
NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -
|
43 |
+
|
44 |
+
|
45 |
+
### Image-Text Retrieval:
|
46 |
+
1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
|
47 |
+
2. To evaluate the finetuned BLIP model on COCO, run:
|
48 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
|
49 |
+
--config ./configs/retrieval_coco.yaml \
|
50 |
+
--output_dir output/retrieval_coco \
|
51 |
+
--evaluate</pre>
|
52 |
+
3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
|
53 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
|
54 |
+
--config ./configs/retrieval_coco.yaml \
|
55 |
+
--output_dir output/retrieval_coco </pre>
|
56 |
+
|
57 |
+
### Image-Text Captioning:
|
58 |
+
1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
|
59 |
+
2. To evaluate the finetuned BLIP model on COCO, run:
|
60 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
|
61 |
+
3. To evaluate the finetuned BLIP model on NoCaps, generate results with: (evaluation needs to be performed on official server)
|
62 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
|
63 |
+
4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
|
64 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
|
65 |
+
|
66 |
+
### VQA:
|
67 |
+
1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
|
68 |
+
2. To evaluate the finetuned BLIP model, generate results with: (evaluation needs to be performed on official server)
|
69 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
|
70 |
+
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
|
71 |
+
<pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
|
72 |
+
|
73 |
+
### NLVR2:
|
74 |
+
1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
|
75 |
+
2. To evaluate the finetuned BLIP model, run
|
76 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
|
77 |
+
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
|
78 |
+
<pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
|
79 |
+
|
80 |
+
### Finetune with ViT-L:
|
81 |
+
In order to finetune a model with ViT-L, simply change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpoint</a> can also be activated in the config file to reduce GPU memory usage.
|
82 |
+
|
83 |
+
### Pre-train:
|
84 |
+
1. Prepare training json files where each json file contains a list. Each item in the list is a dictonary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
|
85 |
+
2. In configs/pretrain.yaml, set 'train_file' as the paths for the json files .
|
86 |
+
3. Pre-train the model using 8 A100 GPUs:
|
87 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain </pre>
|
88 |
+
|
89 |
+
### Zero-shot video-text retrieval:
|
90 |
+
1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
|
91 |
+
2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
|
92 |
+
3. To perform zero-shot evaluation, run
|
93 |
+
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>
|
94 |
+
|
95 |
+
### Pre-training datasets download:
|
96 |
+
We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictonary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.
|
97 |
+
|
98 |
+
Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
|
99 |
+
--- | :---: | :---: | :---:
|
100 |
+
CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
|
101 |
+
LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
|
102 |
+
|
103 |
+
### Citation
|
104 |
+
If you find this code to be useful for your research, please consider citing.
|
105 |
+
<pre>
|
106 |
+
@inproceedings{li2022blip,
|
107 |
+
title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
|
108 |
+
author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
|
109 |
+
year={2022},
|
110 |
+
booktitle={ICML},
|
111 |
+
}</pre>
|
112 |
+
|
113 |
+
### Acknowledgement
|
114 |
+
The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.
|
repositories/BLIP/SECURITY.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Security
|
2 |
+
|
3 |
+
Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
|
4 |
+
as soon as it is discovered. This library limits its runtime dependencies in
|
5 |
+
order to reduce the total cost of ownership as much as can be, but all consumers
|
6 |
+
should remain vigilant and have their security stakeholders review all third-party
|
7 |
+
products (3PP) like this one and their dependencies.
|
repositories/BLIP/cog.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
build:
|
2 |
+
gpu: true
|
3 |
+
cuda: "11.1"
|
4 |
+
python_version: "3.8"
|
5 |
+
system_packages:
|
6 |
+
- "libgl1-mesa-glx"
|
7 |
+
- "libglib2.0-0"
|
8 |
+
python_packages:
|
9 |
+
- "ipython==7.30.1"
|
10 |
+
- "torchvision==0.11.1"
|
11 |
+
- "torch==1.10.0"
|
12 |
+
- "timm==0.4.12"
|
13 |
+
- "transformers==4.15.0"
|
14 |
+
- "fairscale==0.4.4"
|
15 |
+
- "pycocoevalcap==1.2"
|
16 |
+
|
17 |
+
predict: "predict.py:Predictor"
|
repositories/BLIP/configs/bert_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertModel"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"hidden_act": "gelu",
|
7 |
+
"hidden_dropout_prob": 0.1,
|
8 |
+
"hidden_size": 768,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 3072,
|
11 |
+
"layer_norm_eps": 1e-12,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "bert",
|
14 |
+
"num_attention_heads": 12,
|
15 |
+
"num_hidden_layers": 12,
|
16 |
+
"pad_token_id": 0,
|
17 |
+
"type_vocab_size": 2,
|
18 |
+
"vocab_size": 30522,
|
19 |
+
"encoder_width": 768,
|
20 |
+
"add_cross_attention": true
|
21 |
+
}
|