toto10 committed on
Commit
e9c770e
1 Parent(s): 449cca0

13f5d210b8c95ad6b63872633d84859a0a6d9a4258ae3d0e1976b3e737b46fea

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. modules/textual_inversion/logging.py +64 -0
  3. modules/textual_inversion/preprocess.py +232 -0
  4. modules/textual_inversion/test_embedding.png +0 -0
  5. modules/textual_inversion/textual_inversion.py +683 -0
  6. modules/textual_inversion/ui.py +45 -0
  7. modules/timer.py +91 -0
  8. modules/txt2img.py +73 -0
  9. modules/ui.py +0 -0
  10. modules/ui_common.py +244 -0
  11. modules/ui_components.py +74 -0
  12. modules/ui_extensions.py +651 -0
  13. modules/ui_extra_networks.py +496 -0
  14. modules/ui_extra_networks_checkpoints.py +35 -0
  15. modules/ui_extra_networks_hypernets.py +35 -0
  16. modules/ui_extra_networks_textual_inversion.py +35 -0
  17. modules/ui_extra_networks_user_metadata.py +195 -0
  18. modules/ui_gradio_extensions.py +69 -0
  19. modules/ui_loadsave.py +210 -0
  20. modules/ui_postprocessing.py +57 -0
  21. modules/ui_settings.py +296 -0
  22. modules/ui_tempdir.py +85 -0
  23. modules/upscaler.py +144 -0
  24. modules/xlmr.py +137 -0
  25. outputs/txt2img-images/2023-07-30/00000-4104476258.png +0 -0
  26. outputs/txt2img-images/2023-07-30/00001-1264812310.png +0 -0
  27. outputs/txt2img-images/2023-07-30/00002-629074369.png +0 -0
  28. outputs/txt2img-images/2023-07-30/00003-3929529382.png +0 -0
  29. outputs/txt2img-images/2023-07-30/00004-2891905160.png +0 -0
  30. outputs/txt2img-images/2023-07-30/00005-1703927525.png +0 -0
  31. outputs/txt2img-images/2023-07-30/00006-1703927525.png +0 -0
  32. outputs/txt2img-images/2023-07-30/00007-1703927525.png +0 -0
  33. outputs/txt2img-images/2023-07-30/00008-1703927525.png +0 -0
  34. outputs/txt2img-images/2023-07-30/00009-1703927525.jpg +0 -0
  35. outputs/txt2img-images/2023-07-30/00010-210755578.jpg +0 -0
  36. outputs/txt2img-images/2023-07-30/00011-3978311133.jpg +0 -0
  37. outputs/txt2img-images/2023-07-30/00012-3786155085.jpg +0 -0
  38. outputs/txt2img-images/2023-07-30/00013-445379948.jpg +0 -0
  39. outputs/txt2img-images/2023-07-30/00014-3277595636.png +0 -0
  40. package.json +11 -0
  41. params.txt +3 -0
  42. pyproject.toml +35 -0
  43. repositories/BLIP/BLIP.gif +3 -0
  44. repositories/BLIP/CODEOWNERS +2 -0
  45. repositories/BLIP/CODE_OF_CONDUCT.md +105 -0
  46. repositories/BLIP/LICENSE.txt +12 -0
  47. repositories/BLIP/README.md +114 -0
  48. repositories/BLIP/SECURITY.md +7 -0
  49. repositories/BLIP/cog.yaml +17 -0
  50. repositories/BLIP/configs/bert_config.json +21 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text
37
  extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text
37
  extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text
38
+ repositories/BLIP/BLIP.gif filter=lfs diff=lfs merge=lfs -text
modules/textual_inversion/logging.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import json
3
+ import os
4
+
5
# Names of training parameters shared by both textual-inversion and
# hypernetwork training runs; used to select which keys get logged.
saved_params_shared = {
    "batch_size",
    "clip_grad_mode",
    "clip_grad_value",
    "create_image_every",
    "data_root",
    "gradient_step",
    "initial_step",
    "latent_sampling_method",
    "learn_rate",
    "log_directory",
    "model_hash",
    "model_name",
    "num_of_dataset_images",
    "steps",
    "template_file",
    "training_height",
    "training_width",
}
# Parameters specific to textual-inversion (embedding) training.
saved_params_ti = {
    "embedding_name",
    "num_vectors_per_token",
    "save_embedding_every",
    "save_image_with_stored_embedding",
}
# Parameters specific to hypernetwork training.
saved_params_hypernet = {
    "activation_func",
    "add_layer_norm",
    "hypernetwork_name",
    "layer_structure",
    "save_hypernetwork_every",
    "use_dropout",
    "weight_init",
}
# Union of every parameter that is always saved, whichever kind of run it is.
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
# Extra parameters saved only when previews are taken from txt2img settings.
saved_params_previews = {
    "preview_cfg_scale",
    "preview_height",
    "preview_negative_prompt",
    "preview_prompt",
    "preview_sampler_index",
    "preview_seed",
    "preview_steps",
    "preview_width",
}
50
+
51
+
52
def save_settings_to_file(log_directory, all_params):
    """Dump the training settings relevant to this run into a timestamped
    JSON file inside *log_directory*.

    Only keys listed in saved_params_all are taken from *all_params*; the
    preview_* keys are included too when 'preview_from_txt2img' is enabled.
    """
    now = datetime.datetime.now()
    params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}

    keys = saved_params_all
    if all_params.get('preview_from_txt2img'):
        keys = keys | saved_params_previews

    params.update({k: v for k, v in all_params.items() if k in keys})

    filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
    # Explicit utf8: prompt text may contain non-ASCII characters, which would
    # crash json.dump on platforms whose locale encoding can't represent them.
    with open(os.path.join(log_directory, filename), "w", encoding="utf8") as file:
        json.dump(params, file, indent=4)
modules/textual_inversion/preprocess.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image, ImageOps
3
+ import math
4
+ import tqdm
5
+
6
+ from modules import paths, shared, images, deepbooru
7
+ from modules.textual_inversion import autocrop
8
+
9
+
10
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
    """Entry point for dataset preprocessing.

    Loads the captioning models that were requested (BLIP interrogator and/or
    deepbooru), delegates all actual work to preprocess_work(), and guarantees
    the models are released again even if preprocessing raises.

    NOTE(review): id_task is accepted but unused here — presumably consumed by
    a task-dispatch wrapper around this function; confirm against callers.
    """
    try:
        if process_caption:
            shared.interrogator.load()

        if process_caption_deepbooru:
            deepbooru.model.start()

        preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)

    finally:
        # Release model memory whether or not preprocessing succeeded.
        if process_caption:
            shared.interrogator.send_blip_to_ram()

        if process_caption_deepbooru:
            deepbooru.model.stop()
+
28
+
29
def listfiles(dirname):
    """Return the entries of *dirname* in sorted order.

    os.listdir() order is arbitrary and filesystem-dependent; sorting makes the
    processing order — and therefore the numeric prefixes of the files emitted
    by preprocess_work() — deterministic across runs.
    """
    return sorted(os.listdir(dirname))
31
+
32
+
33
class PreprocessParams:
    # Mutable bag of per-image state shared between preprocess_work() and the
    # save_pic* helpers; class attributes act as defaults.
    src = None                          # path of the source image currently being processed
    dstdir = None                       # destination directory for processed images
    subindex = 0                        # counter for multiple outputs from one source image
    flip = False                        # also save a mirrored copy of each output
    process_caption = False             # caption with the BLIP interrogator
    process_caption_deepbooru = False   # caption with deepbooru tags
    preprocess_txt_action = None        # how to merge an existing .txt caption: 'prepend'/'append'/'copy'
41
+
42
+
43
def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
    """Save one processed image into params.dstdir and, when any caption text
    exists, a matching .txt caption file.

    The output basename is "<index:05>-<subindex>-<source stem>"; subindex is
    incremented at the end so repeated calls for the same source image produce
    distinct filenames.
    """
    caption = ""

    if params.process_caption:
        caption += shared.interrogator.generate_caption(image)

    if params.process_caption_deepbooru:
        if caption:
            caption += ", "
        caption += deepbooru.model.tag_multi(image)

    # Take the stem of the source file for use in the output name.
    filename_part = params.src
    filename_part = os.path.splitext(filename_part)[0]
    filename_part = os.path.basename(filename_part)

    basename = f"{index:05}-{params.subindex}-{filename_part}"
    image.save(os.path.join(params.dstdir, f"{basename}.png"))

    # Merge with the pre-existing caption according to the configured action.
    if params.preprocess_txt_action == 'prepend' and existing_caption:
        caption = f"{existing_caption} {caption}"
    elif params.preprocess_txt_action == 'append' and existing_caption:
        caption = f"{caption} {existing_caption}"
    elif params.preprocess_txt_action == 'copy' and existing_caption:
        caption = existing_caption

    caption = caption.strip()

    # Only write a .txt file when there is any caption text at all.
    if caption:
        with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
            file.write(caption)

    params.subindex += 1
75
+
76
+
77
def save_pic(image, index, params, existing_caption=None):
    """Save the image via save_pic_with_caption(), plus a horizontally
    mirrored copy when params.flip is enabled."""
    variants = [image]
    if params.flip:
        variants.append(ImageOps.mirror(image))
    for variant in variants:
        save_pic_with_caption(variant, index, params, existing_caption=existing_caption)
82
+
83
+
84
def split_pic(image, inverse_xy, width, height, overlap_ratio):
    """Scale *image* so its short side matches the target, then yield
    overlapping (width, height) crops along the long side.

    When inverse_xy is True the image is landscape and splitting happens along
    x; otherwise along y. Consecutive crops overlap by roughly overlap_ratio.

    Fix: when exactly one crop fits (split_count == 1), the original step
    formula divided by split_count - 1 and raised ZeroDivisionError.
    """
    if inverse_xy:
        from_w, from_h = image.height, image.width
        to_w, to_h = height, width
    else:
        from_w, from_h = image.width, image.height
        to_w, to_h = width, height
    # Scaled length of the long side once the short side is resized to to_w.
    h = from_h * to_w // from_w
    if inverse_xy:
        image = image.resize((h, to_w))
    else:
        image = image.resize((to_w, h))

    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
    # Guard the single-crop case; one crop starts at offset 0.
    y_step = (h - to_h) / (split_count - 1) if split_count > 1 else 0
    for i in range(split_count):
        y = int(y_step * i)
        if inverse_xy:
            splitted = image.crop((y, 0, y + to_h, to_w))
        else:
            splitted = image.crop((0, y, to_w, y + to_h))
        yield splitted
106
+
107
# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
def center_crop(image: Image, w: int, h: int):
    """Resize *image* to (w, h), cropping the centered region that matches the
    target aspect ratio; the crop box may have fractional coordinates."""
    src_w, src_h = image.size
    if src_h / h < src_w / w:
        # Source is wider than the target aspect: trim equal margins left/right.
        crop_w = w * src_h / h
        margin = (src_w - crop_w) / 2
        region = margin, 0, src_w - margin, src_h
    else:
        # Source is taller than the target aspect: trim equal margins top/bottom.
        crop_h = h * src_w / w
        margin = (src_h - crop_h) / 2
        region = 0, margin, src_w, src_h - margin
    return image.resize((w, h), Image.Resampling.LANCZOS, region)
117
+
118
+
119
def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
    """Choose the best (w, h) on a 64-px grid within the given dimension/area
    limits whose aspect-ratio error versus the source stays within *threshold*,
    then center-crop the image to it. Returns None when nothing qualifies."""
    src_w, src_h = image.size

    def aspect_error(w, h):
        # 1 - min(r, 1/r) for the ratio r between source and candidate aspect;
        # 0 means a perfect aspect match.
        r = src_w / src_h / (w / h)
        return 1 - (r if r < 1 else 1 / r)

    candidates = [
        (w, h)
        for w in range(mindim, maxdim + 1, 64)
        for h in range(mindim, maxdim + 1, 64)
        if minarea <= w * h <= maxarea and aspect_error(w, h) <= threshold
    ]
    if not candidates:
        return None

    def preference(size):
        w, h = size
        key = (w * h, -aspect_error(w, h))
        # 'Maximize area' prefers area first; otherwise aspect match first.
        return key if objective == 'Maximize area' else key[::-1]

    best = max(candidates, key=preference)
    return center_crop(image, *best)
128
+
129
+
130
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
    """Process every image in process_src into training-ready images in process_dst.

    For each readable source image, depending on the flags, this may: split
    elongated images into overlapping tiles, focal-crop non-square images,
    multicrop to the best-fitting size, keep the original size, and/or fall
    back to a plain resize. Every produced image is saved (optionally with a
    caption) through save_pic(). Progress is mirrored into shared.state.
    """
    width = process_width
    height = process_height
    src = os.path.abspath(process_src)
    dst = os.path.abspath(process_dst)
    # Clamp user-supplied ratios to sane ranges.
    split_threshold = max(0.0, min(1.0, split_threshold))
    overlap_ratio = max(0.0, min(0.9, overlap_ratio))

    assert src != dst, 'same directory specified as source and destination'

    os.makedirs(dst, exist_ok=True)

    files = listfiles(src)

    shared.state.job = "preprocess"
    shared.state.textinfo = "Preprocessing..."
    shared.state.job_count = len(files)

    params = PreprocessParams()
    params.dstdir = dst
    params.flip = process_flip
    params.process_caption = process_caption
    params.process_caption_deepbooru = process_caption_deepbooru
    params.preprocess_txt_action = preprocess_txt_action

    pbar = tqdm.tqdm(files)
    for index, imagefile in enumerate(pbar):
        params.subindex = 0
        filename = os.path.join(src, imagefile)
        try:
            img = Image.open(filename)
            img = ImageOps.exif_transpose(img)
            img = img.convert("RGB")
        except Exception:
            # Not an image, or unreadable — skip it silently.
            continue

        description = f"Preprocessing [Image {index}/{len(files)}]"
        pbar.set_description(description)
        shared.state.textinfo = description

        params.src = filename

        # Pick up a pre-existing caption .txt next to the source image, if any.
        existing_caption = None
        existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
        if os.path.exists(existing_caption_filename):
            with open(existing_caption_filename, 'r', encoding="utf8") as file:
                existing_caption = file.read()

        if shared.state.interrupted:
            break

        # ratio < 1 means the image is more elongated than the target shape;
        # inverse_xy marks landscape orientation (split along x instead of y).
        if img.height > img.width:
            ratio = (img.width * height) / (img.height * width)
            inverse_xy = False
        else:
            ratio = (img.height * width) / (img.width * height)
            inverse_xy = True

        # Fall through to a plain resize only when no other mode produced output.
        process_default_resize = True

        if process_split and ratio < 1.0 and ratio <= split_threshold:
            for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
                save_pic(splitted, index, params, existing_caption=existing_caption)
            process_default_resize = False

        if process_focal_crop and img.height != img.width:

            dnn_model_path = None
            try:
                dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
            except Exception as e:
                print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)

            autocrop_settings = autocrop.Settings(
                crop_width = width,
                crop_height = height,
                face_points_weight = process_focal_crop_face_weight,
                entropy_points_weight = process_focal_crop_entropy_weight,
                corner_points_weight = process_focal_crop_edges_weight,
                annotate_image = process_focal_crop_debug,
                dnn_model_path = dnn_model_path,
            )
            for focal in autocrop.crop_image(img, autocrop_settings):
                save_pic(focal, index, params, existing_caption=existing_caption)
            process_default_resize = False

        if process_multicrop:
            cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
            if cropped is not None:
                save_pic(cropped, index, params, existing_caption=existing_caption)
            else:
                # NOTE(review): "(unknown)" reads like a mangled f-string
                # placeholder (presumably {filename}) — verify against upstream.
                print(f"skipped {img.width}x{img.height} image (unknown) (can't find suitable size within error threshold)")
            process_default_resize = False

        if process_keep_original_size:
            save_pic(img, index, params, existing_caption=existing_caption)
            process_default_resize = False

        if process_default_resize:
            img = images.resize_image(1, img, width, height)
            save_pic(img, index, params, existing_caption=existing_caption)

        shared.state.nextjob()
modules/textual_inversion/test_embedding.png ADDED
modules/textual_inversion/textual_inversion.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import namedtuple
3
+ from contextlib import closing
4
+
5
+ import torch
6
+ import tqdm
7
+ import html
8
+ import datetime
9
+ import csv
10
+ import safetensors.torch
11
+
12
+ import numpy as np
13
+ from PIL import Image, PngImagePlugin
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
17
+ import modules.textual_inversion.dataset
18
+ from modules.textual_inversion.learn_schedule import LearnRateScheduler
19
+
20
+ from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
21
+ from modules.textual_inversion.logging import save_settings_to_file
22
+
23
+
24
# A prompt-template file usable during training: display name + path on disk.
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
# Mapping filename -> TextualInversionTemplate, (re)filled by
# list_textual_inversion_templates().
textual_inversion_templates = {}
26
+
27
+
28
def list_textual_inversion_templates():
    """Rescan the configured templates directory and rebuild the module-level
    textual_inversion_templates mapping (filename -> template record)."""
    textual_inversion_templates.clear()

    for directory, _, filenames in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
        for filename in filenames:
            full_path = os.path.join(directory, filename)
            textual_inversion_templates[filename] = TextualInversionTemplate(filename, full_path)

    return textual_inversion_templates
38
+
39
+
40
class Embedding:
    """A single textual-inversion embedding: a (vectors, dim) weight tensor
    plus training metadata."""

    def __init__(self, vec, name, step=None):
        self.vec = vec                     # the embedding weights
        self.name = name                   # trigger word / embedding name
        self.step = step                   # training step the weights are from
        self.shape = None                  # per-vector width, filled on load
        self.vectors = 0                   # number of vectors, filled on load
        self.cached_checksum = None        # memoized result of checksum()
        self.sd_checkpoint = None          # checkpoint hash it was trained on
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None   # optional optimizer state to persist
        self.filename = None
        self.hash = None
        self.shorthash = None

    def save(self, filename):
        """Serialize the embedding in webui's .pt format to *filename*, and the
        optimizer state to "<filename>.optim" when enabled."""
        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            # Bug fix: the optimizer state was written to the literal path
            # "(unknown).optim" (an f-string with no placeholder) instead of
            # next to the embedding file.
            torch.save(optimizer_saved_dict, f"{filename}.optim")

    def checksum(self):
        """Cheap 4-hex-digit checksum of the weights; memoized per instance."""
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum

    def set_hash(self, v):
        """Record the full sha256 hash and its 12-character short form."""
        self.hash = v
        self.shorthash = self.hash[0:12]
90
+
91
+
92
class DirWithTextualInversionEmbeddings:
    """Tracks one embeddings directory and whether it changed on disk since
    the last update(), using the directory's modification time."""

    def __init__(self, path):
        self.path = path
        self.mtime = None    # mtime recorded at last update(); None = never scanned

    def has_changed(self):
        """Return True when the directory exists and is newer than the last
        recorded scan (or was never scanned); False otherwise."""
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        if self.mtime is None or mt > self.mtime:
            return True

        # Fix: previously fell through and returned None implicitly; make the
        # negative result an explicit (still falsy) bool.
        return False

    def update(self):
        """Record the directory's current mtime as the last-scanned state."""
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)
110
+
111
+
112
class EmbeddingDatabase:
    """Registry of all loaded textual-inversion embeddings.

    Maintains a lookup from the first token id of each embedding's name to the
    candidate (token-ids, Embedding) pairs starting with that token, and tracks
    the directories embeddings came from so they can be reloaded on change.
    """

    def __init__(self):
        self.ids_lookup = {}                       # first token id -> [(ids, Embedding), ...], longest ids first
        self.word_embeddings = {}                  # embedding name -> Embedding
        self.skipped_embeddings = {}               # embeddings whose width doesn't match the loaded model
        self.expected_shape = -1                   # embedding width the model expects; -1 = not yet determined
        self.embedding_dirs = {}                   # path -> DirWithTextualInversionEmbeddings
        self.previously_displayed_embeddings = ()  # last console-printed state, to avoid re-printing

    def add_embedding_dir(self, path):
        # Start tracking *path* as a source of embedding files.
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        return self.register_embedding_by_name(embedding, model, embedding.name)

    def register_embedding_by_name(self, embedding, model, name):
        """Register *embedding* under *name*; pass embedding=None to unregister."""
        ids = model.cond_stage_model.tokenize([name])[0]
        first_id = ids[0]
        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []
        if name in self.word_embeddings:
            # remove old one from the lookup list
            lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
        else:
            lookup = self.ids_lookup[first_id]
        if embedding is not None:
            lookup += [(ids, embedding)]
        # Longest token sequences first, so the longest match wins at lookup time.
        self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
        if embedding is None:
            # unregister embedding with specified name
            if name in self.word_embeddings:
                del self.word_embeddings[name]
            if len(self.ids_lookup[first_id])==0:
                del self.ids_lookup[first_id]
            return None
        self.word_embeddings[name] = embedding
        return embedding

    def get_expected_shape(self):
        # Encode a trivial token to discover the model's embedding width.
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_from_file(self, path, filename):
        """Load one embedding from an image (.png/.webp/.jxl/.avif with embedded
        data), a .pt/.bin torch file, or a .safetensors file; silently return
        for anything else (including ".preview" images)."""
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

            embed_image = Image.open(path)
            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                name = data.get('name', name)
            else:
                data = extract_image_data_embed(embed_image)
                if data:
                    name = data.get('name', name)
                else:
                    # if data is None, means this is not an embeding, just a preview image
                    return
        elif ext in ['.BIN', '.PT']:
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            data = safetensors.torch.load_file(path, device="cpu")
        else:
            return

        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
        # diffuser concepts
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            # NOTE(review): "(unknown)" looks like a mangled f-string placeholder
            # (presumably {filename}) — verify against upstream before changing.
            raise Exception(f"Couldn't identify (unknown) as neither textual inversion embedding nor diffuser concept.")

        vec = emb.detach().to(devices.device, dtype=torch.float32)
        embedding = Embedding(vec, name)
        embedding.step = data.get('step', None)
        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
        embedding.vectors = vec.shape[0]
        embedding.shape = vec.shape[-1]
        embedding.filename = path
        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')

        # Only register embeddings whose width matches the loaded model.
        if self.expected_shape == -1 or self.expected_shape == embedding.shape:
            self.register_embedding(embedding, shared.sd_model)
        else:
            self.skipped_embeddings[name] = embedding

    def load_from_dir(self, embdir):
        """Load every non-empty file under *embdir*; per-file errors are
        reported and skipped rather than aborting the scan."""
        if not os.path.isdir(embdir.path):
            return

        for root, _, fns in os.walk(embdir.path, followlinks=True):
            for fn in fns:
                try:
                    fullfn = os.path.join(root, fn)

                    if os.stat(fullfn).st_size == 0:
                        continue

                    self.load_from_file(fullfn, fn)
                except Exception:
                    errors.report(f"Error loading embedding {fn}", exc_info=True)
                    continue

    def load_textual_inversion_embeddings(self, force_reload=False):
        """Rebuild the database from all tracked directories; when force_reload
        is False, do nothing unless some directory changed on disk."""
        if not force_reload:
            need_reload = False
            for embdir in self.embedding_dirs.values():
                if embdir.has_changed():
                    need_reload = True
                    break

            if not need_reload:
                return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        for embdir in self.embedding_dirs.values():
            self.load_from_dir(embdir)
            embdir.update()

        # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
        # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
        sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
        self.word_embeddings.clear()
        self.word_embeddings.update(sorted_word_embeddings)

        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
        if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
            self.previously_displayed_embeddings = displayed_embeddings
            print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
            if self.skipped_embeddings:
                print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        """Return (embedding, token_count) when the token run at *offset*
        matches a registered embedding's token ids; (None, None) otherwise."""
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
277
+
278
+
279
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    """Create a new textual-inversion embedding .pt file and return its path.

    The vectors are initialized from the conditioning model's encoding of
    *init_text*, or left as zeros when init_text is empty. The sanitized name
    becomes the filename under the embeddings directory; unless overwrite_old
    is set, an existing file of the same name raises AssertionError.
    """
    cond_model = shared.sd_model.cond_stage_model

    with devices.autocast():
        cond_model([""]) # will send cond model to GPU if lowvram/medvram is active

    #cond_model expects at least some text, so we provide '*' as backup.
    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

    #Only copy if we provided an init_text, otherwise keep vectors as zeros
    if init_text:
        for i in range(num_vectors_per_token):
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

    # Remove illegal characters from name.
    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    embedding = Embedding(vec, name)
    embedding.step = 0
    embedding.save(fn)

    return fn
305
+
306
+
307
def write_loss(log_directory, filename, step, epoch_len, values):
    """Append one row of training stats to the CSV log, writing the header
    when the file is new. Honors the training_write_csv_every option: logs
    nothing when it is 0, and only on matching steps otherwise."""
    if shared.opts.training_write_csv_every == 0:
        return

    if step % shared.opts.training_write_csv_every != 0:
        return

    csv_path = os.path.join(log_directory, filename)
    write_csv_header = not os.path.exists(csv_path)

    with open(csv_path, "a+", newline='') as fout:
        csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])

        if write_csv_header:
            csv_writer.writeheader()

        # steps are 1-based, so step-1 places the step inside its epoch
        epoch = (step - 1) // epoch_len
        epoch_step = (step - 1) % epoch_len

        row = {"step": step, "epoch": epoch, "epoch_step": epoch_step}
        row.update(values)
        csv_writer.writerow(row)
330
+
331
def tensorboard_setup(log_directory):
    """Create <log_directory>/tensorboard and return a SummaryWriter on it."""
    tb_dir = os.path.join(log_directory, "tensorboard")
    os.makedirs(tb_dir, exist_ok=True)
    return SummaryWriter(log_dir=tb_dir, flush_secs=shared.opts.training_tensorboard_flush_every)
336
+
337
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
    """Log loss and learn-rate scalars, both globally and per-epoch."""
    entries = (
        ("Loss/train", loss, global_step),
        (f"Loss/train/epoch-{epoch_num}", loss, step),
        ("Learn rate/train", learn_rate, global_step),
        (f"Learn rate/train/epoch-{epoch_num}", learn_rate, step),
    )
    for tag, value, at_step in entries:
        tensorboard_add_scaler(tensorboard_writer, tag, value, at_step)
342
+
343
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
    """Thin wrapper around SummaryWriter.add_scalar with keyword arguments."""
    tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step)
346
+
347
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
    """Log a PIL image to tensorboard as a CHW uint tensor."""
    # PIL yields HWC pixel data; add_image expects channels-first (CHW).
    width, height = pil_image.size
    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
    img_tensor = img_tensor.view(height, width, len(pil_image.getbands()))
    img_tensor = img_tensor.permute((2, 0, 1))

    tensorboard_writer.add_image(tag, img_tensor, global_step=step)
355
+
356
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    """Validate all inputs to a training run before it starts.

    Raises AssertionError with a user-readable message on the first problem
    found; *name* names the kind of model being trained ("embedding" or
    "hypernetwork") and is interpolated into the messages.
    """
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
    assert gradient_step > 0, "Gradient accumulation step must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_filename, "Prompt template file not selected"
    assert template_file, f"Prompt template file {template_filename} not found"
    assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    # Bug fix: these two messages were plain strings missing the f prefix, so
    # users saw the literal text "{name}" instead of the model kind.
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"
378
+
379
+
380
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the textual-inversion embedding named `embedding_name`.

    Runs an AdamW + AMP training loop over the dataset in `data_root`,
    periodically saving checkpoints of the embedding and preview images under
    `log_directory`, and returns (embedding, filename) where `filename` is the
    final .pt path in the embeddings directory.
    """
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    template_file = textual_inversion_templates.get(template_filename, None)
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
    template_file = template_file.path

    shared.state.job = "train-embedding"
    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload = shared.opts.unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
    else:
        embedding_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    if create_image_every > 0 and save_image_with_stored_embedding:
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None

    hijack = sd_hijack.model_hijack

    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()

    initial_step = embedding.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = tensorboard_setup(log_directory)

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)

    if shared.opts.save_training_settings_to_txt:
        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        # free VRAM: the VAE is only needed again when rendering preview images
        shared.parallel_processing_allowed = False
        shared.sd_model.first_stage_model.to(devices.cpu)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    if shared.opts.save_optimizer_state:
        optimizer_state_dict = None
        # bugfix: the sidecar path was a corrupted f-string with no placeholder;
        # the optimizer state is stored next to the embedding as "<name>.pt.optim"
        if os.path.exists(f"{filename}.optim"):
            optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0  # internal

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"
    embedding_yet_to_be_embedded = False

    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
    img_c = None

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        sd_hijack_checkpoint.add()

        for _ in range((steps - initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, embedding.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(embedding.step)

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if use_weight:
                        w = batch.weight.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)

                    if is_training_inpainting_model:
                        if img_c is None:
                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)

                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
                    else:
                        cond = c

                    if use_weight:
                        loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
                        del w
                    else:
                        loss = shared.sd_model.forward(x, cond)[0] / gradient_step
                    del x

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue

                if clip_grad:
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = embedding.step + 1

                epoch_num = embedding.step // steps_per_epoch
                epoch_step = embedding.step % steps_per_epoch

                description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
                pbar.set_description(description)
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{embedding_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                        do_not_reload_embeddings=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    with closing(p):
                        processed = processing.process_images(p)
                        image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    if image is not None:
                        shared.state.assign_current_image(image)

                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                            tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:

                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')

                        info = PngImagePlugin.PngInfo()
                        data = torch.load(last_saved_file)
                        info.add_text("sd-ti-embedding", embedding_to_b64(data))

                        title = f"<{data.get('name', '???')}>"

                        try:
                            vectorSize = list(data['string_to_param'].values())[0].shape[0]
                        except Exception:
                            vectorSize = '?'

                        checkpoint = sd_models.select_checkpoint()
                        footer_left = checkpoint.model_name
                        footer_mid = f'[{checkpoint.shorthash}]'
                        footer_right = f'{vectorSize}v {steps_done}s'

                        captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                        captioned_image = insert_image_data_embed(captioned_image, data)

                        captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
                        embedding_yet_to_be_embedded = False

                    last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                    last_saved_image += f", prompt: {preview_text}"

                shared.state.job_no = embedding.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        errors.report("Error training embedding", exc_info=True)
    finally:
        pbar.leave = False
        pbar.close()
        shared.sd_model.first_stage_model.to(devices.device)
        shared.parallel_processing_allowed = old_parallel_processing_allowed
        sd_hijack_checkpoint.remove()

    return embedding, filename
663
+
664
+
665
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Write `embedding` to `filename`, stamping it with the current checkpoint.

    The embedding's identifying fields (name, checkpoint hash/name, cached
    checksum) are updated before saving; if saving fails they are rolled back
    to their previous values and the exception is re-raised.
    """
    # remember previous state so a failed save leaves the object untouched
    previous = (
        embedding.name,
        getattr(embedding, "sd_checkpoint", None),
        getattr(embedding, "sd_checkpoint_name", None),
        getattr(embedding, "cached_checksum", None),
    )
    try:
        embedding.sd_checkpoint = checkpoint.shorthash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except:  # noqa: E722 - deliberately broad: roll back on *any* failure, then re-raise
        embedding.name, embedding.sd_checkpoint, embedding.sd_checkpoint_name, embedding.cached_checksum = previous
        raise
modules/textual_inversion/ui.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+
3
+ import gradio as gr
4
+
5
+ import modules.textual_inversion.textual_inversion
6
+ import modules.textual_inversion.preprocess
7
+ from modules import sd_hijack, shared
8
+
9
+
10
def create_embedding(name, initialization_text, nvpt, overwrite_old):
    """Gradio handler: create a new textual-inversion embedding file.

    Creates the embedding with `nvpt` vectors per token initialized from
    `initialization_text`, reloads the embedding database, and returns
    (dropdown update with the refreshed embedding list, status message, "").
    """
    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

    # bugfix: the status message was a corrupted f-string with no placeholder,
    # so the created file's path was never shown to the user
    return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
16
+
17
+
18
def preprocess(*args):
    """Gradio handler: run dataset preprocessing and report the outcome.

    Forwards all arguments to the preprocessing implementation and returns
    (status message, "") noting whether the run finished or was interrupted.
    """
    modules.textual_inversion.preprocess.preprocess(*args)

    outcome = 'interrupted' if shared.state.interrupted else 'finished'
    return f"Preprocessing {outcome}.", ""
22
+
23
+
24
def train_embedding(*args):
    """Gradio handler: train a textual-inversion embedding.

    Temporarily undoes cross-attention optimizations (unless the user opted to
    train with them), delegates to the training implementation, and returns
    (result message, ""). Optimizations are restored in all cases.
    """
    assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'

    apply_optimizations = shared.opts.training_xattention_optimizations
    try:
        if not apply_optimizations:
            sd_hijack.undo_optimizations()

        embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)

        res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
        return res, ""
    # cleanup: the original `except Exception: raise` clause was a no-op and was removed;
    # try/finally alone propagates the exception while restoring optimizations
    finally:
        if not apply_optimizations:
            sd_hijack.apply_optimizations()
45
+
modules/timer.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import argparse
3
+
4
+
5
class TimerSubcategory:
    """Context manager that scopes a Timer's records under "<category>/".

    While active, times recorded on the parent timer are filed under the
    nested category prefix; on exit the total time spent in the subcategory
    is recorded and the previous prefix is restored.
    """

    def __init__(self, timer, category):
        self.timer = timer
        self.category = category
        self.start = None
        self.original_base_category = timer.base_category

    def __enter__(self):
        self.start = time.time()
        self.timer.base_category = f"{self.original_base_category}{self.category}/"
        self.timer.subcategory_level += 1

        if self.timer.print_log:
            indent = ' ' * self.timer.subcategory_level
            print(f"{indent}{self.category}:")

    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = time.time() - self.start
        self.timer.base_category = self.original_base_category
        self.timer.add_time_to_record(self.original_base_category + self.category, duration)
        self.timer.subcategory_level -= 1
        self.timer.record(self.category, disable_log=True)
26
+
27
+
28
class Timer:
    """Accumulates wall-clock time per named category.

    Call record(category) to file the time elapsed since the previous
    record()/elapsed() call; use subcategory(name) as a context manager to
    nest categories under "name/". When print_log is True, each record is
    also printed (indented by nesting level).
    """

    def __init__(self, print_log=False):
        self.start = time.time()
        self.records = {}          # category -> accumulated seconds
        self.total = 0             # sum over all recorded intervals
        self.base_category = ''    # current "a/b/" prefix set by subcategories
        self.print_log = print_log
        self.subcategory_level = 0

    def elapsed(self):
        """Return seconds since the last checkpoint and reset the checkpoint."""
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def add_time_to_record(self, category, amount):
        if category not in self.records:
            self.records[category] = 0

        self.records[category] += amount

    def record(self, category, extra_time=0, disable_log=False):
        """File elapsed time (plus extra_time) under base_category + category."""
        e = self.elapsed()

        self.add_time_to_record(self.base_category + category, e + extra_time)

        self.total += e + extra_time

        if self.print_log and not disable_log:
            print(f"{'  ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")

    def subcategory(self, name):
        """Return a context manager that nests subsequent records under name/."""
        self.elapsed()

        subcat = TimerSubcategory(self, name)
        return subcat

    def summary(self):
        """Return e.g. "12.3s (load: 10.0s, setup: 2.0s)"; only top-level categories >= 0.1s are listed."""
        res = f"{self.total:.1f}s"

        additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
        if not additions:
            return res

        res += " ("
        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
        res += ")"

        return res

    def dump(self):
        return {'total': self.total, 'records': self.records}

    def reset(self):
        # bugfix: previously called self.__init__() with no arguments, which
        # silently reset print_log to False and disabled startup logging
        self.__init__(print_log=self.print_log)
83
+
84
+
85
# Parse only --log-startup here (add_help=False + parse_known_args) so this
# module can be imported before the application's main argument parser runs,
# without choking on the other command-line flags.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
args = parser.parse_known_args()[0]

# Global timer used to measure webui startup phases.
startup_timer = Timer(print_log=args.log_startup)

# Presumably holds a snapshot of startup_timer.dump() once startup completes;
# stays None until then — TODO confirm against the code that assigns it.
startup_record = None
modules/txt2img.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import closing
2
+
3
+ import modules.scripts
4
+ from modules import sd_samplers, processing
5
+ from modules.generation_parameters_copypaste import create_override_settings_dict
6
+ from modules.shared import opts, cmd_opts
7
+ import modules.shared as shared
8
+ from modules.ui import plaintext_to_html
9
+ import gradio as gr
10
+
11
+
12
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
    """Gradio handler for the txt2img tab.

    Builds a StableDiffusionProcessingTxt2Img from the UI widget values, lets
    registered scripts run first (falling back to normal processing when no
    script handles the job), and returns (images, generation info JSON,
    info HTML, comments HTML). `*args` are the script UI values.
    """
    override_settings = create_override_settings_dict(override_settings_texts)

    p = processing.StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
        styles=prompt_styles,
        negative_prompt=negative_prompt,
        seed=seed,
        subseed=subseed,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
        seed_enable_extras=seed_enable_extras,
        sampler_name=sd_samplers.samplers[sampler_index].name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        restore_faces=restore_faces,
        tiling=tiling,
        enable_hr=enable_hr,
        # hires-fix settings only apply when hires fix is enabled
        denoising_strength=denoising_strength if enable_hr else None,
        hr_scale=hr_scale,
        hr_upscaler=hr_upscaler,
        hr_second_pass_steps=hr_second_pass_steps,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
        # hr sampler dropdown index 0 means "use same sampler as first pass"
        hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
        hr_prompt=hr_prompt,
        hr_negative_prompt=hr_negative_prompt,
        override_settings=override_settings,
    )

    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args

    p.user = request.username

    if cmd_opts.enable_console_prompts:
        print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

    # closing() guarantees p's resources are released even if processing raises
    with closing(p):
        # a script may take over the whole generation; returns None otherwise
        processed = modules.scripts.scripts_txt2img.run(p, *args)

        if processed is None:
            processed = processing.process_images(p)

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    if opts.samples_log_stdout:
        print(generation_info_js)

    if opts.do_not_show_images:
        processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
modules/ui.py ADDED
The diff for this file is too large to render. See raw diff
 
modules/ui_common.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import html
3
+ import os
4
+ import platform
5
+ import sys
6
+
7
+ import gradio as gr
8
+ import subprocess as sp
9
+
10
+ from modules import call_queue, shared
11
+ from modules.generation_parameters_copypaste import image_from_url_text
12
+ import modules.images
13
+ from modules.ui_components import ToolButton
14
+
15
+
16
folder_symbol = '\U0001f4c2'  # 📂 glyph for the "open output folder" button
refresh_symbol = '\U0001f504'  # 🔄 glyph for refresh buttons
18
+
19
+
20
def update_generation_info(generation_info, html_info, img_index):
    """Return the infotext of the selected gallery image rendered as HTML.

    `generation_info` is the JSON produced by Processed.js(); `img_index` is
    the selected gallery index. On any parse failure or an out-of-range index,
    the previous `html_info` is returned unchanged.
    """
    try:
        parsed = json.loads(generation_info)
        infotexts = parsed["infotexts"]
        if 0 <= img_index < len(infotexts):
            return plaintext_to_html(infotexts[img_index]), gr.update()
    except Exception:
        # malformed json / missing keys: keep whatever was shown before
        pass
    return html_info, gr.update()
30
+
31
+
32
def plaintext_to_html(text, classname=None):
    """Render plain text as one HTML paragraph.

    Markup characters are escaped and newlines become <br> tags; when
    `classname` is given it is emitted as the paragraph's class attribute.
    """
    escaped_lines = [html.escape(line) for line in text.split('\n')]
    content = "<br>\n".join(escaped_lines)
    attr = f" class='{classname}'" if classname else ""
    return f"<p{attr}>{content}</p>"
36
+
37
+
38
def save_files(js_data, images, do_make_zip, index):
    """Save gallery images (and optionally a zip of them) to opts.outdir_save.

    js_data: JSON string of the generation's Processed data; images: gallery
    file data; index: selected gallery index (-1 for none). Appends a row per
    save to log.csv and returns (gr.File update with saved paths, status HTML).
    """
    import csv
    filenames = []
    fullfns = []

    #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
    class MyObject:
        def __init__(self, d=None):
            if d is not None:
                for key, value in d.items():
                    setattr(self, key, value)

    data = json.loads(js_data)

    p = MyObject(data)
    path = shared.opts.outdir_save
    save_to_dirs = shared.opts.use_save_to_dirs_for_ui
    extension: str = shared.opts.samples_format
    start_index = 0
    only_one = False

    if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]):  # ensures we are looking at a specific non-grid picture, and we have save_selected_only
        only_one = True
        images = [images[index]]
        start_index = index

    os.makedirs(shared.opts.outdir_save, exist_ok=True)

    # append-mode log; header is written only when the file is new/empty
    with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])

        for image_index, filedata in enumerate(images, start_index):
            image = image_from_url_text(filedata)

            # entries before index_of_first_image are grid images
            is_grid = image_index < p.index_of_first_image
            i = 0 if is_grid else (image_index - p.index_of_first_image)

            p.batch_index = image_index-1
            fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

            filename = os.path.relpath(fullfn, path)
            filenames.append(filename)
            fullfns.append(fullfn)
            if txt_fullfn:
                filenames.append(os.path.basename(txt_fullfn))
                fullfns.append(txt_fullfn)

        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

    # Make Zip
    if do_make_zip:
        zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
        namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
        zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
        zip_filepath = os.path.join(path, f"{zip_filename}.zip")

        from zipfile import ZipFile
        with ZipFile(zip_filepath, "w") as zip_file:
            for i in range(len(fullfns)):
                with open(fullfns[i], mode="rb") as f:
                    zip_file.writestr(filenames[i], f.read())
        # zip goes first so it is the primary download link
        fullfns.insert(0, zip_filepath)

    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
105
+
106
+
107
def create_output_panel(tabname, outdir):
    """Build the shared results panel (gallery, save/zip buttons, info boxes) for a tab.

    Returns (result_gallery, generation_info component — or the extras tab's
    html_info_x —, html_info, html_log). Must be called inside a gr.Blocks
    context.
    """
    from modules import shared
    import modules.generation_parameters_copypaste as parameters_copypaste

    def open_folder(f):
        # Opens f with the OS file manager; refuses non-folders as a safety measure.
        if not os.path.exists(f):
            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
            return
        elif not os.path.isdir(f):
            print(f"""
WARNING
An open_folder request was made with an argument that is not a folder.
This could be an error or a malicious attempt to run code on your computer.
Requested path was: {f}
""", file=sys.stderr)
            return

        if not shared.cmd_opts.hide_ui_dir_config:
            path = os.path.normpath(f)
            if platform.system() == "Windows":
                os.startfile(path)
            elif platform.system() == "Darwin":
                sp.Popen(["open", path])
            elif "microsoft-standard-WSL2" in platform.uname().release:
                sp.Popen(["wsl-open", path])
            else:
                sp.Popen(["xdg-open", path])

    with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
        with gr.Group(elem_id=f"{tabname}_gallery_container"):
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)

        generation_info = None
        with gr.Column():
            with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
                open_folder_button = gr.Button(folder_symbol, elem_id=f'open_folder_{tabname}', visible=not shared.cmd_opts.hide_ui_dir_config)

                # the extras tab has no per-image save/zip buttons
                if tabname != "extras":
                    save = gr.Button('Save', elem_id=f'save_{tabname}')
                    save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')

                buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])

            open_folder_button.click(
                fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
                inputs=[],
                outputs=[],
            )

            if tabname != "extras":
                with gr.Row():
                    download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')

                with gr.Accordion("Generation Info", open=False):
                    html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
                    html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")

                    generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
                    if tabname == 'txt2img' or tabname == 'img2img':
                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
                        generation_info_button.click(
                            fn=update_generation_info,
                            # the js side substitutes the selected gallery index for the third input
                            _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
                            inputs=[generation_info, html_info, html_info],
                            outputs=[html_info, html_info],
                            show_progress=False,
                        )

                    save.click(
                        fn=call_queue.wrap_gradio_call(save_files),
                        _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
                        inputs=[
                            generation_info,
                            result_gallery,
                            html_info,
                            html_info,
                        ],
                        outputs=[
                            download_files,
                            html_log,
                        ],
                        show_progress=False,
                    )

                    save_zip.click(
                        fn=call_queue.wrap_gradio_call(save_files),
                        # same handler as save, but with do_make_zip=True
                        _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
                        inputs=[
                            generation_info,
                            result_gallery,
                            html_info,
                            html_info,
                        ],
                        outputs=[
                            download_files,
                            html_log,
                        ]
                    )

            else:
                html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
                html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
                html_log = gr.HTML(elem_id=f'html_log_{tabname}')

            paste_field_names = []
            if tabname == "txt2img":
                # NOTE(review): relies on modules.scripts being importable via the
                # `import modules.images` chain — confirm against the file's imports
                paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
            elif tabname == "img2img":
                paste_field_names = modules.scripts.scripts_img2img.paste_field_names

            for paste_tabname, paste_button in buttons.items():
                parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                    paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
                    paste_field_names=paste_field_names
                ))

            return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
225
+
226
+
227
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Build a small 🔄 ToolButton that refreshes another component.

    Clicking the button calls `refresh_method()`, then applies
    `refreshed_args` (a dict, or a callable returning one) to
    `refresh_component` both directly and via a gr.update.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args

        for key, value in args.items():
            setattr(refresh_component, key, value)

        return gr.update(**(args or {}))

    button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    button.click(fn=refresh, inputs=[], outputs=[refresh_component])
    return button
244
+
modules/ui_components.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
class FormComponent:
    """Mixin that tells gradio a component belongs inside a gr.Form container."""
    def get_expected_parent(self):
        return gr.components.Form
7
+
8
+
9
# Monkeypatch the stock Dropdown so it can also sit inside gradio forms.
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
10
+
11
+
12
class ToolButton(FormComponent, gr.Button):
    """Small button with single emoji as text, fits inside gradio forms"""

    def __init__(self, *args, **kwargs):
        # prepend the "tool" css class to whatever classes the caller passed
        extra_classes = kwargs.pop("elem_classes", [])
        super().__init__(*args, elem_classes=["tool", *extra_classes], **kwargs)

    def get_block_name(self):
        return "button"
21
+
22
+
23
class FormRow(FormComponent, gr.Row):
    """Same as gr.Row but fits inside gradio forms (via FormComponent)"""

    def get_block_name(self):
        return "row"
28
+
29
+
30
class FormColumn(FormComponent, gr.Column):
    """Same as gr.Column but fits inside gradio forms (via FormComponent)"""

    def get_block_name(self):
        return "column"
35
+
36
+
37
class FormGroup(FormComponent, gr.Group):
    """Same as gr.Group but fits inside gradio forms (docstring previously said gr.Row — copy-paste error)"""

    def get_block_name(self):
        return "group"
42
+
43
+
44
class FormHTML(FormComponent, gr.HTML):
    """Same as gr.HTML but fits inside gradio forms (via FormComponent)"""

    def get_block_name(self):
        return "html"
49
+
50
+
51
class FormColorPicker(FormComponent, gr.ColorPicker):
    """Same as gr.ColorPicker but fits inside gradio forms (via FormComponent)"""

    def get_block_name(self):
        return "colorpicker"
56
+
57
+
58
class DropdownMulti(FormComponent, gr.Dropdown):
    """Same as gr.Dropdown but always multiselect"""
    def __init__(self, **kwargs):
        # multiselect is forced on; callers must not pass it themselves
        super().__init__(multiselect=True, **kwargs)

    def get_block_name(self):
        return "dropdown"
65
+
66
+
67
class DropdownEditable(FormComponent, gr.Dropdown):
    """Same as gr.Dropdown but allows editing value"""
    def __init__(self, **kwargs):
        # allow_custom_value is forced on; callers must not pass it themselves
        super().__init__(allow_custom_value=True, **kwargs)

    def get_block_name(self):
        return "dropdown"
74
+
modules/ui_extensions.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import threading
4
+ import time
5
+ from datetime import datetime
6
+
7
+ import git
8
+
9
+ import gradio as gr
10
+ import html
11
+ import shutil
12
+ import errno
13
+
14
+ from modules import extensions, shared, paths, config_states, errors, restart
15
+ from modules.paths_internal import config_states_dir
16
+ from modules.call_queue import wrap_gradio_gpu_call
17
+
18
# cache of the downloaded extension index; replaced by refresh_available_extensions()
available_extensions = {"extensions": []}
# inline style used to highlight values that differ from the current installed state
STYLE_PRIMARY = ' style="color: var(--primary-400)"'
20
+
21
+
22
def check_access():
    """Raise if modifying extensions is blocked by the --disable-extension-access flag."""
    # explicit raise instead of a bare assert so the check still runs under `python -O`
    if shared.cmd_opts.disable_extension_access:
        raise AssertionError("extension access disabled because of command line flags")
24
+
25
+
26
def apply_and_restart(disable_list, update_list, disable_all):
    """Apply the Installed tab's selections, save options, then restart or stop.

    disable_list, update_list: JSON-encoded lists of extension names from the browser.
    disable_all: one of "none"/"extra"/"all" (the global disable setting).
    """
    check_access()

    disabled = json.loads(disable_list)
    # isinstance instead of `type(...) ==`: idiomatic and accepts list subclasses
    assert isinstance(disabled, list), f"wrong disable_list data for apply_and_restart: {disable_list}"

    update = json.loads(update_list)
    assert isinstance(update, list), f"wrong update_list data for apply_and_restart: {update_list}"

    if update:
        # snapshot the current state first so a bad update can be rolled back
        save_config_state("Backup (pre-update)")

    update = set(update)

    for ext in extensions.extensions:
        if ext.name not in update:
            continue

        try:
            ext.fetch_and_reset_hard()
        except Exception:
            # keep going: one broken extension should not block the others
            errors.report(f"Error getting updates for {ext.name}", exc_info=True)

    shared.opts.disabled_extensions = disabled
    shared.opts.disable_all_extensions = disable_all
    shared.opts.save(shared.config_filename)

    if restart.is_restartable():
        restart.restart_program()
    else:
        restart.stop_program()
57
+
58
+
59
def save_config_state(name):
    """Save a snapshot of the current webui/extension state under config_states_dir.

    Returns the updated saved-configs dropdown and a confirmation HTML message.
    """
    current_config_state = config_states.get_config()
    if not name:
        name = "Config"
    current_config_state["name"] = name
    timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
    filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
    # bug fix: the messages previously contained a literal "(unknown)" placeholder
    # instead of interpolating the actual backup file path
    print(f"Saving backup of webui/extension state to {filename}.")
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(current_config_state, f)
    config_states.list_config_states()
    new_value = next(iter(config_states.all_config_states.keys()), "Current")
    new_choices = ["Current"] + list(config_states.all_config_states.keys())
    return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
73
+
74
+
75
def restore_config_state(confirmed, config_state_name, restore_type):
    """Schedule restoring a saved config state; restore_type picks extensions, webui, or both."""
    if config_state_name == "Current":
        return "<span>Select a config to restore from.</span>"
    if not confirmed:
        return "<span>Cancelled.</span>"

    check_access()

    config_state = config_states.all_config_states[config_state_name]

    print(f"*** Restoring webui state from backup: {restore_type} ***")

    if restore_type in ("extensions", "both"):
        # extension restore happens on next startup; record which file to restore from
        shared.opts.restore_config_state_file = config_state["filepath"]
        shared.opts.save(shared.config_filename)

    if restore_type in ("webui", "both"):
        config_states.restore_webui_config(config_state)

    shared.state.request_restart()

    return ""
97
+
98
+
99
def check_updates(id_task, disable_list):
    """Check every enabled remote-backed extension for updates, updating shared.state
    progress as it goes; returns the refreshed extension table HTML."""
    check_access()

    disabled = json.loads(disable_list)
    # fixed: message previously named apply_and_restart; use isinstance per idiom
    assert isinstance(disabled, list), f"wrong disable_list data for check_updates: {disable_list}"

    exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
    shared.state.job_count = len(exts)

    for ext in exts:
        shared.state.textinfo = ext.name

        try:
            ext.check_updates()
        except FileNotFoundError as e:
            # a missing FETCH_HEAD just means the repo was never fetched; anything else is real
            if 'FETCH_HEAD' not in str(e):
                raise
        except Exception:
            errors.report(f"Error checking updates for {ext.name}", exc_info=True)

        shared.state.nextjob()

    return extension_table(), ""
122
+
123
+
124
def make_commit_link(commit_hash, remote, text=None):
    """Return an HTML link to the commit for GitHub remotes, or plain text otherwise.

    When text is None the first 8 characters of the hash are used as the label.
    """
    label = commit_hash[:8] if text is None else text

    if not remote.startswith("https://github.com/"):
        return label

    base = remote[:-4] if remote.endswith(".git") else remote
    return f'<a href="{base}/commit/{commit_hash}" target="_blank">{label}</a>'
134
+
135
+
136
def extension_table():
    """Build the HTML table for the Installed tab, one row per extension with
    enable/update checkboxes; the leading timestamp comment defeats HTML caching."""
    code = f"""<!-- {time.time()} -->
    <table id="extensions">
        <thead>
            <tr>
                <th>
                    <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
                    <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
                </th>
                <th>URL</th>
                <th>Branch</th>
                <th>Version</th>
                <th>Date</th>
                <th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
            </tr>
        </thead>
        <tbody>
    """

    for ext in extensions.extensions:
        ext: extensions.Extension
        ext.read_info_from_repo()  # refresh branch/commit info from the local git repo

        remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""

        if ext.can_update:
            # updatable extensions get a pre-checked "update" checkbox in the status cell
            ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
        else:
            ext_status = ext.status

        style = ""
        # highlight rows for extensions that are force-disabled by the global setting
        if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
            style = STYLE_PRIMARY

        version_link = ext.version
        if ext.commit_hash and ext.remote:
            version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)

        code += f"""
            <tr>
                <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
                <td>{remote}</td>
                <td>{ext.branch}</td>
                <td>{version_link}</td>
                <td>{time.asctime(time.gmtime(ext.commit_date))}</td>
                <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
            </tr>
    """

    code += """
        </tbody>
    </table>
    """

    return code
191
+
192
+
193
def update_config_states_table(state_name):
    """Render HTML describing a saved config state ("Current" shows the live state),
    highlighting every field that differs from what is currently installed."""
    if state_name == "Current":
        config_state = config_states.get_config()
    else:
        config_state = config_states.all_config_states[state_name]

    config_name = config_state.get("name", "Config")
    created_date = time.asctime(time.gmtime(config_state["created_at"]))
    filepath = config_state.get("filepath", "<unknown>")

    code = f"""<!-- {time.time()} -->"""

    webui_remote = config_state["webui"]["remote"] or ""
    webui_branch = config_state["webui"]["branch"]
    webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
    webui_commit_date = config_state["webui"]["commit_date"]
    if webui_commit_date:
        webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
    else:
        webui_commit_date = "<unknown>"

    remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
    commit_link = make_commit_link(webui_commit_hash, webui_remote)
    date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)

    current_webui = config_states.get_webui_config()

    # STYLE_PRIMARY marks values in the backup that differ from the running webui
    style_remote = ""
    style_branch = ""
    style_commit = ""
    if current_webui["remote"] != webui_remote:
        style_remote = STYLE_PRIMARY
    if current_webui["branch"] != webui_branch:
        style_branch = STYLE_PRIMARY
    if current_webui["commit_hash"] != webui_commit_hash:
        style_commit = STYLE_PRIMARY

    code += f"""<h2>Config Backup: {config_name}</h2>
    <div><b>Filepath:</b> {filepath}</div>
    <div><b>Created at:</b> {created_date}</div>"""

    code += f"""<h2>WebUI State</h2>
    <table id="config_state_webui">
        <thead>
            <tr>
                <th>URL</th>
                <th>Branch</th>
                <th>Commit</th>
                <th>Date</th>
            </tr>
        </thead>
        <tbody>
            <tr>
                <td><label{style_remote}>{remote}</label></td>
                <td><label{style_branch}>{webui_branch}</label></td>
                <td><label{style_commit}>{commit_link}</label></td>
                <td><label{style_commit}>{date_link}</label></td>
            </tr>
        </tbody>
    </table>
    """

    code += """<h2>Extension State</h2>
    <table id="config_state_extensions">
        <thead>
            <tr>
                <th>Extension</th>
                <th>URL</th>
                <th>Branch</th>
                <th>Commit</th>
                <th>Date</th>
            </tr>
        </thead>
        <tbody>
    """

    ext_map = {ext.name: ext for ext in extensions.extensions}

    for ext_name, ext_conf in config_state["extensions"].items():
        ext_remote = ext_conf["remote"] or ""
        ext_branch = ext_conf["branch"] or "<unknown>"
        ext_enabled = ext_conf["enabled"]
        ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
        ext_commit_date = ext_conf["commit_date"]
        if ext_commit_date:
            ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
        else:
            ext_commit_date = "<unknown>"

        remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
        commit_link = make_commit_link(ext_commit_hash, ext_remote)
        date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)

        # same highlighting scheme as the webui table, per extension that still exists
        style_enabled = ""
        style_remote = ""
        style_branch = ""
        style_commit = ""
        if ext_name in ext_map:
            current_ext = ext_map[ext_name]
            current_ext.read_info_from_repo()
            if current_ext.enabled != ext_enabled:
                style_enabled = STYLE_PRIMARY
            if current_ext.remote != ext_remote:
                style_remote = STYLE_PRIMARY
            if current_ext.branch != ext_branch:
                style_branch = STYLE_PRIMARY
            if current_ext.commit_hash != ext_commit_hash:
                style_commit = STYLE_PRIMARY

        code += f"""
            <tr>
                <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
                <td><label{style_remote}>{remote}</label></td>
                <td><label{style_branch}>{ext_branch}</label></td>
                <td><label{style_commit}>{commit_link}</label></td>
                <td><label{style_commit}>{date_link}</label></td>
            </tr>
    """

    code += """
        </tbody>
    </table>
    """

    return code
318
+
319
+
320
def normalize_git_url(url):
    """Normalize a git remote URL for comparison: None becomes "" and a trailing
    ".git" suffix is removed.

    Bug fix: the previous `url.replace(".git", "")` removed ".git" anywhere in the
    URL (mangling hosts/paths like "my.github.io"); only the suffix is stripped now.
    """
    if url is None:
        return ""

    if url.endswith(".git"):
        url = url[:-4]
    return url
326
+
327
+
328
def install_extension_from_url(dirname, url, branch_name=None):
    """Clone the git repository at `url` into the extensions directory.

    dirname: target directory name; derived from the last URL segment when empty.
    branch_name: branch to clone; the repository's default branch when falsy.
    Returns [extension table HTML, status message].
    """
    check_access()

    if isinstance(dirname, str):
        dirname = dirname.strip()
    if isinstance(url, str):
        url = url.strip()

    assert url, 'No URL specified'

    if dirname is None or dirname == "":
        # derive the directory name from the repo name in the URL
        *parts, last_part = url.split('/')
        last_part = normalize_git_url(last_part)

        dirname = last_part

    target_dir = os.path.join(extensions.extensions_dir, dirname)
    assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'

    normalized_url = normalize_git_url(url)
    if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):
        raise Exception(f'Extension with this URL is already installed: {url}')

    # clone into a temp dir first so a failed clone never leaves a broken extension behind
    tmpdir = os.path.join(paths.data_path, "tmp", dirname)

    try:
        shutil.rmtree(tmpdir, True)
        if not branch_name:
            # if no branch is specified, use the default branch
            with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
                repo.remote().fetch()
                for submodule in repo.submodules:
                    submodule.update()
        else:
            with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
                repo.remote().fetch()
                for submodule in repo.submodules:
                    submodule.update()
        try:
            os.rename(tmpdir, target_dir)
        except OSError as err:
            if err.errno == errno.EXDEV:
                # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
                # Since we can't use a rename, do the slower but more versatile shutil.move()
                shutil.move(tmpdir, target_dir)
            else:
                # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
                raise err

        import launch
        launch.run_extension_installer(target_dir)

        extensions.list_extensions()
        return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
    finally:
        shutil.rmtree(tmpdir, True)
384
+
385
+
386
def install_extension_from_index(url, hide_tags, sort_column, filter_text):
    """Install an extension picked from the Available table, then refresh both tables."""
    ext_table, message = install_extension_from_url(None, url)

    available_code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)

    return available_code, ext_table, message, ''
392
+
393
+
394
def refresh_available_extensions(url, hide_tags, sort_column):
    """Download the extension index JSON from `url` and rebuild the Available table."""
    global available_extensions

    import urllib.request
    with urllib.request.urlopen(url) as response:
        text = response.read()

    available_extensions = json.loads(text)

    code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)

    # also resets the search box and install-result message alongside the new table
    return url, code, gr.CheckboxGroup.update(choices=tags), '', ''
406
+
407
+
408
def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text):
    """Re-render the Available table after the hidden-tags or sort selection changed."""
    table_html, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)

    return table_html, ''
412
+
413
+
414
def search_extensions(filter_text, hide_tags, sort_column):
    """Re-render the Available table for a new search string."""
    table_html, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)

    return table_html, ''
418
+
419
+
420
sort_ordering = [
    # (reverse, order_by_function) — index-aligned with the "Order" radio in create_ui:
    # newest first, oldest first, a-z, z-a, internal order, update time, create time, stars
    (True, lambda x: x.get('added', 'z')),
    (False, lambda x: x.get('added', 'z')),
    (False, lambda x: x.get('name', 'z')),
    (True, lambda x: x.get('name', 'z')),
    (False, lambda x: 'z'),  # constant key: keeps the index's original (internal) order
    (True, lambda x: x.get('commit_time', '')),
    (True, lambda x: x.get('created_at', '')),
    (True, lambda x: x.get('stars', 0)),
]
431
+
432
+
433
def get_date(info: dict, key):
    """Return info[key] (an ISO-8601 "...Z" timestamp) reformatted as YYYY-MM-DD,
    or '' when the key is missing or the value does not parse."""
    raw = info.get(key)
    try:
        parsed = datetime.strptime(raw, "%Y-%m-%dT%H:%M:%SZ")
    except (ValueError, TypeError):
        return ''
    return parsed.strftime("%Y-%m-%d")
438
+
439
+
440
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
    """Render the Available-extensions table from the cached index.

    hide_tags: tags whose extensions are hidden; sort_column: index into
    sort_ordering; filter_text: substring matched against name/description.
    Returns (table HTML, list of all known tags).
    """
    extlist = available_extensions["extensions"]
    installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}

    tags = available_extensions.get("tags", {})
    tags_to_hide = set(hide_tags)
    hidden = 0

    code = f"""<!-- {time.time()} -->
    <table id="available_extensions">
        <thead>
            <tr>
                <th>Extension</th>
                <th>Description</th>
                <th>Action</th>
            </tr>
        </thead>
        <tbody>
    """

    # out-of-range sort_column falls back to the first ordering
    sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]

    for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
        name = ext.get("name", "noname")
        stars = int(ext.get("stars", 0))
        added = ext.get('added', 'unknown')
        update_time = get_date(ext, 'commit_time')
        create_time = get_date(ext, 'created_at')
        url = ext.get("url", None)
        description = ext.get("description", "")
        extension_tags = ext.get("tags", [])

        if url is None:
            continue

        # an already-installed extension implicitly gets the "installed" tag
        existing = installed_extension_urls.get(normalize_git_url(url), None)
        extension_tags = extension_tags + ["installed"] if existing else extension_tags

        if any(x for x in extension_tags if x in tags_to_hide):
            hidden += 1
            continue

        if filter_text and filter_text.strip():
            if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():
                hidden += 1
                continue

        install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""

        tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])

        code += f"""
            <tr>
                <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
                <td>{html.escape(description)}<p class="info">
                <span class="date_added">Update: {html.escape(update_time)}  Added: {html.escape(added)}  Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
                <td>{install_code}</td>
            </tr>

    """

        # remember tags that appear only on items so they can be offered as filters
        for tag in [x for x in extension_tags if x not in tags]:
            tags[tag] = tag

    code += """
        </tbody>
    </table>
    """

    if hidden > 0:
        code += f"<p>Extension hidden: {hidden}</p>"

    return code, list(tags)
513
+
514
+
515
def preload_extensions_git_metadata():
    # warm the per-extension git info cache in a background thread (see create_ui)
    for extension in extensions.extensions:
        extension.read_info_from_repo()
518
+
519
+
520
def create_ui():
    """Build the Extensions page with its four tabs:
    Installed, Available, Install from URL, and Backup/Restore."""
    import modules.ui

    config_states.list_config_states()

    # git metadata reads are slow; warm them up without blocking UI construction
    threading.Thread(target=preload_extensions_git_metadata).start()

    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Tabs(elem_id="tabs_extensions"):
            with gr.TabItem("Installed", id="installed"):

                with gr.Row(elem_id="extensions_installed_top"):
                    apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
                    apply = gr.Button(value=apply_label, variant="primary")
                    check = gr.Button(value="Check for updates")
                    extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
                    # hidden fields populated by the extensions_apply/extensions_check JS hooks
                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
                    extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)

                html = ""
                if shared.opts.disable_all_extensions != "none":
                    html = """
                        <span style="color: var(--primary-400);">
                            "Disable all extensions" was set, change it to "none" to load all extensions again
                        </span>
                    """
                info = gr.HTML(html)
                extensions_table = gr.HTML('Loading...')
                ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])

                apply.click(
                    fn=apply_and_restart,
                    _js="extensions_apply",
                    inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],
                    outputs=[],
                )

                check.click(
                    fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
                    _js="extensions_check",
                    inputs=[info, extensions_disabled_list],
                    outputs=[extensions_table, info],
                )

            with gr.TabItem("Available", id="available"):
                with gr.Row():
                    refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
                    extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
                    available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
                    # hidden controls driven by the per-row Install buttons' JS
                    extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
                    install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)

                with gr.Row():
                    hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
                    # type="index" makes sort_column arrive as an int index into sort_ordering
                    sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")

                with gr.Row():
                    search_extensions_text = gr.Text(label="Search").style(container=False)

                install_result = gr.HTML()
                available_extensions_table = gr.HTML()

                refresh_available_extensions_button.click(
                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
                    inputs=[available_extensions_index, hide_tags, sort_column],
                    outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
                )

                install_extension_button.click(
                    fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
                    inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text],
                    outputs=[available_extensions_table, extensions_table, install_result],
                )

                search_extensions_text.change(
                    fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
                    inputs=[search_extensions_text, hide_tags, sort_column],
                    outputs=[available_extensions_table, install_result],
                )

                hide_tags.change(
                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
                    inputs=[hide_tags, sort_column, search_extensions_text],
                    outputs=[available_extensions_table, install_result]
                )

                sort_column.change(
                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
                    inputs=[hide_tags, sort_column, search_extensions_text],
                    outputs=[available_extensions_table, install_result]
                )

            with gr.TabItem("Install from URL", id="install_from_url"):
                install_url = gr.Text(label="URL for extension's git repository")
                install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
                install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
                install_button = gr.Button(value="Install", variant="primary")
                install_result = gr.HTML(elem_id="extension_install_result")

                install_button.click(
                    # extra leading gr.update() clears the URL textbox on success
                    fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
                    inputs=[install_dirname, install_url, install_branch],
                    outputs=[install_url, extensions_table, install_result],
                )

            with gr.TabItem("Backup/Restore"):
                with gr.Row(elem_id="extensions_backup_top_row"):
                    config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
                    modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
                    config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
                    config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
                with gr.Row(elem_id="extensions_backup_top_row2"):
                    config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
                    config_save_button = gr.Button(value="Save Current Config")

                config_states_info = gr.HTML("")
                config_states_table = gr.HTML("Loading...")
                ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])

                config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])

                # the JS hook shows a browser confirm() and feeds the answer in via dummy_component
                dummy_component = gr.Label(visible=False)
                config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])

                config_states_list.change(
                    fn=update_config_states_table,
                    inputs=[config_states_list],
                    outputs=[config_states_table],
                )


    return ui
modules/ui_extra_networks.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import urllib.parse
3
+ from pathlib import Path
4
+
5
+ from modules import shared, ui_extra_networks_user_metadata, errors
6
+ from modules.images import read_info_from_image, save_image_with_geninfo
7
+ from modules.ui import up_down_symbol
8
+ import gradio as gr
9
+ import json
10
+ import html
11
+
12
+ from modules.ui_components import ToolButton
13
+ from fastapi.exceptions import HTTPException
14
+
15
+ from modules.generation_parameters_copypaste import image_from_url_text
16
+ from modules.ui_components import ToolButton
17
+
18
extra_pages = []  # registered ExtraNetworksPage instances
allowed_dirs = set()  # union of all pages' preview dirs; whitelist for fetch_file
refresh_symbol = '\U0001f504' # 🔄
#clear_symbol = '\U0001F5D9' # 🗙
22
+
23
def register_page(page):
    """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""

    extra_pages.append(page)

    # rebuild the preview-directory whitelist from every registered page
    allowed_dirs.clear()
    for registered in extra_pages:
        allowed_dirs.update(registered.allowed_directories_for_previews())
29
+
30
+
31
def fetch_file(filename: str = ""):
    """API endpoint serving a preview image for an extra-networks card.

    Only image files that live inside a directory registered via register_page
    may be fetched. Bug fix: both error messages previously contained a literal
    "(unknown)" placeholder instead of interpolating the offending filename.
    """
    from starlette.responses import FileResponse

    if not os.path.isfile(filename):
        raise HTTPException(status_code=404, detail="File not found")

    if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
        raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")

    ext = os.path.splitext(filename)[1].lower()
    if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
        raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")

    # would profit from returning 304
    return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
46
+
47
+
48
def get_metadata(page: str = "", item: str = ""):
    """API endpoint: return pretty-printed JSON metadata for one item of one page."""
    from starlette.responses import JSONResponse

    matching_page = next((x for x in extra_pages if x.name == page), None)
    if matching_page is None:
        return JSONResponse({})

    metadata = matching_page.metadata.get(item)
    if metadata is None:
        return JSONResponse({})

    return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
60
+
61
+
62
def get_single_card(page: str = "", tabname: str = "", name: str = ""):
    """API endpoint: rebuild and return the HTML for a single extra-networks card.

    NOTE(review): `page` becomes None when no registered page matches; the
    attribute accesses below would then raise — presumably callers only pass
    valid page names. Confirm with the JS caller.
    """
    from starlette.responses import JSONResponse

    page = next(iter([x for x in extra_pages if x.name == page]), None)

    try:
        item = page.create_item(name, enable_filter=False)
        page.items[name] = item
    except Exception as e:
        errors.display(e, "creating item for extra network")
        item = page.items.get(name)  # fall back to the previously cached item

    page.read_user_metadata(item)
    item_html = page.create_html_for_item(item, tabname)

    return JSONResponse({"html": item_html})
78
+
79
+
80
def add_pages_to_demo(app):
    # register the extra-networks HTTP endpoints on the FastAPI app
    app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
    app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
    app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
84
+
85
+
86
def quote_js(s):
    """Return s as a double-quoted JavaScript string literal, escaping backslashes
    and double quotes."""
    escaped = s.replace('\\', '\\\\').replace('"', '\\"')
    return '"' + escaped + '"'
90
+
91
+
92
class ExtraNetworksPage:
    """Base class for one tab of the extra networks UI (checkpoints, LoRA, etc.)."""
    def __init__(self, title):
        self.title = title
        self.name = title.lower()
        self.id_page = self.name.replace(" ", "_")  # used in HTML element ids
        self.card_page = shared.html("extra-networks-card.html")  # per-card HTML template
        self.allow_negative_prompt = False
        self.metadata = {}  # item name -> raw metadata dict
        self.items = {}  # item name -> item description dict
101
+
102
    def refresh(self):
        # hook for subclasses: re-scan their model directory on the refresh button
        pass
104
+
105
    def read_user_metadata(self, item):
        """Load the optional per-item ".json" user metadata file stored next to the
        item's file, attaching it as item["user_metadata"] (and "description" if set)."""
        filename = item.get("filename", None)
        basename, ext = os.path.splitext(filename)
        metadata_filename = basename + '.json'

        metadata = {}
        try:
            if os.path.isfile(metadata_filename):
                with open(metadata_filename, "r", encoding="utf8") as file:
                    metadata = json.load(file)
        except Exception as e:
            # a corrupt metadata file should not break the whole page
            errors.display(e, f"reading extra network user metadata from {metadata_filename}")

        desc = metadata.get("description", None)
        if desc is not None:
            item["description"] = desc

        item["user_metadata"] = metadata
123
+
124
    def link_preview(self, filename):
        """Return a URL for the thumb endpoint; mtime in the query busts browser caches."""
        quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
        mtime = os.path.getmtime(filename)
        return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
128
+
129
+ def search_terms_from_path(self, filename, possible_directories=None):
130
+ abspath = os.path.abspath(filename)
131
+
132
+ for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
133
+ parentdir = os.path.abspath(parentdir)
134
+ if abspath.startswith(parentdir):
135
+ return abspath[len(parentdir):].replace('\\', '/')
136
+
137
+ return ""
138
+
139
+
140
+ def create_html(self, tabname):
141
+ view = "cards" #shared.opts.extra_networks_default_view
142
+ items_html = ''
143
+
144
+ self.metadata = {}
145
+
146
+ subdirs = {}
147
+ for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
148
+ for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
149
+ for dirname in sorted(dirs, key=shared.natural_sort_key):
150
+ x = os.path.join(root, dirname)
151
+
152
+ if not os.path.isdir(x):
153
+ continue
154
+
155
+ subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
156
+ while subdir.startswith("/"):
157
+ subdir = subdir[1:]
158
+
159
+ is_empty = len(os.listdir(x)) == 0
160
+ if not is_empty and not subdir.endswith("/"):
161
+ subdir = subdir + "/"
162
+
163
+ if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
164
+ continue
165
+
166
+ subdirs[subdir] = 1
167
+
168
+ if subdirs:
169
+ subdirs = {"": 1, **subdirs}
170
+
171
+
172
+ #<option value='{html.escape(subdir if subdir!="" else "all")}'>{html.escape(subdir if subdir!="" else "all")}</option>
173
+ subdirs_html = "".join([f"""
174
+ <button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
175
+ {html.escape(subdir if subdir!="" else "all")}
176
+ </button>
177
+ """ for subdir in subdirs])
178
+
179
+ self.items = {x["name"]: x for x in self.list_items()}
180
+ for item in self.items.values():
181
+ metadata = item.get("metadata")
182
+ if metadata:
183
+ self.metadata[item["name"]] = metadata
184
+
185
+ if "user_metadata" not in item:
186
+ self.read_user_metadata(item)
187
+
188
+ items_html += self.create_html_for_item(item, tabname)
189
+
190
+ if items_html == '':
191
+ dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
192
+ items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
193
+
194
+ self_name_id = self.name.replace(" ", "_")
195
+
196
+ # <select onchange='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
197
+ # {subdirs_html}
198
+ # </select>
199
+ res = f"""
200
+ <div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
201
+ {subdirs_html}
202
+ </div>
203
+ <div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
204
+ {items_html}
205
+ </div>
206
+ """
207
+
208
+ return res
209
+
210
+ def create_item(self, name, index=None):
211
+ raise NotImplementedError()
212
+
213
+ def list_items(self):
214
+ raise NotImplementedError()
215
+
216
+ def allowed_directories_for_previews(self):
217
+ return []
218
+
219
+ def create_html_for_item(self, item, tabname):
220
+ """
221
+ Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
222
+ """
223
+
224
+ preview = item.get("preview", None)
225
+
226
+ onclick = item.get("onclick", None)
227
+ if onclick is None:
228
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
229
+
230
+ #height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
231
+ #width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
232
+ background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
233
+
234
+ metadata_button = ""
235
+ metadata = item.get("metadata")
236
+ if metadata:
237
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
238
+
239
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
240
+
241
+ local_path = ""
242
+ filename = item.get("filename", "")
243
+
244
+ for reldir in self.allowed_directories_for_previews():
245
+ absdir = os.path.abspath(reldir)
246
+ if filename.startswith(absdir):
247
+ local_path = filename[len(absdir):]
248
+
249
+ # if this is true, the item must not be shown in the default view, and must instead only be
250
+ # shown when searching for it
251
+
252
+ if shared.opts.extra_networks_hidden_models == "Always":
253
+ search_only = False
254
+ else:
255
+ search_only = "/." in local_path or "\\." in local_path
256
+
257
+ if search_only and shared.opts.extra_networks_hidden_models == "Never":
258
+ return ""
259
+
260
+ sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
261
+
262
+ args = {
263
+ #"background_image": background_image,
264
+ #"style": f"'display: none; {height}{width}'",
265
+ "preview_image": html.escape(preview) if preview else './file=html/card-no-preview.png',
266
+ "prompt": item.get("prompt", None),
267
+ "tabname": quote_js(tabname),
268
+ "local_preview": quote_js(item["local_preview"]),
269
+ "name": item["name"],
270
+ "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
271
+ "card_clicked": onclick,
272
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
273
+ "search_term": item.get("search_term", ""),
274
+ "metadata_button": metadata_button,
275
+ "edit_button": edit_button,
276
+ "search_only": " search_only" if search_only else "",
277
+ "sort_keys": sort_keys,
278
+ }
279
+
280
+ return self.card_page.format(**args)
281
+
282
+ def get_sort_keys(self, path):
283
+ """
284
+ List of default keys used for sorting in the UI.
285
+ """
286
+ pth = Path(path)
287
+ stat = pth.stat()
288
+ return {
289
+ "date_created": int(stat.st_ctime or 0),
290
+ "date_modified": int(stat.st_mtime or 0),
291
+ "name": pth.name.lower(),
292
+ }
293
+
294
+ def find_preview(self, path):
295
+ """
296
+ Find a preview PNG for a given path (without extension) and call link_preview on it.
297
+ """
298
+
299
+ preview_extensions = ["png", "jpg", "jpeg", "webp"]
300
+ if shared.opts.samples_format not in preview_extensions:
301
+ preview_extensions.append(shared.opts.samples_format)
302
+
303
+ # file_name = os.path.basename(path)
304
+ # location = os.path.dirname(path)
305
+ # preview_path = location + "/preview/" + file_name
306
+ # potential_files = sum([[path + "." + ext, path + ".preview." + ext, preview_path + "." + ext, preview_path + ".preview." + ext] for ext in preview_extensions], [])
307
+
308
+ potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
309
+
310
+ for file in potential_files:
311
+ if os.path.isfile(file):
312
+ return self.link_preview(file)
313
+
314
+ for file in potential_files:
315
+ if os.path.isfile(file):
316
+ return self.link_preview(file)
317
+
318
+ return None
319
+
320
+ def find_description(self, path):
321
+ """
322
+ Find and read a description file for a given path (without extension).
323
+ """
324
+ for file in [f"{path}.txt", f"{path}.description.txt"]:
325
+ try:
326
+ with open(file, "r", encoding="utf-8", errors="replace") as f:
327
+ return f.read()
328
+ except OSError:
329
+ pass
330
+ return None
331
+
332
+ def create_user_metadata_editor(self, ui, tabname):
333
+ return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
334
+
335
+
336
def initialize():
    """Reset the registry of extra-network pages to an empty state."""
    extra_pages.clear()
338
+
339
+
340
def register_default_pages():
    """Register the three built-in pages; registration order sets the default tab order."""
    from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
    from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
    from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints

    for page_class in (ExtraNetworksPageTextualInversion, ExtraNetworksPageHypernetworks, ExtraNetworksPageCheckpoints):
        register_page(page_class())
347
+
348
+
349
class ExtraNetworksUi:
    """Plain holder for the gradio components that make up one tab's extra-networks panel."""

    def __init__(self):
        # gradio HTML components related to extra networks' pages
        self.pages = None

        # HTML content of the above; empty initially, filled when extra pages have to be shown
        self.page_contents = None

        # fix: create_ui()/refresh() elsewhere in this module assign ui.pages_contents and
        # ui.user_metadata_editors without them being declared here; declare them so the
        # attribute set is visible in one place (page_contents kept for backward compatibility)
        self.pages_contents = None
        self.user_metadata_editors = None

        self.stored_extra_pages = None

        self.button_save_preview = None
        self.preview_target_filename = None

        self.tabname = None
363
+
364
+
365
def pages_in_preferred_order(pages):
    """Sort *pages* according to the `ui_extra_networks_tab_reorder` option.

    A page whose lowercase name contains the i-th comma-separated entry of the
    option sorts to rank i; pages matching nothing keep their original relative
    order at the end.
    """
    preferred = [entry.lower().strip() for entry in shared.opts.ui_extra_networks_tab_reorder.split(",")]

    def rank_of(page_name):
        lowered = page_name.lower()
        for rank, fragment in enumerate(preferred):
            if fragment in lowered:
                return rank

        # unmatched pages sort after every matched one
        return len(pages)

    ranking = {page.name: (rank_of(page.name), position) for position, page in enumerate(pages)}

    return sorted(pages, key=lambda page: ranking[page.name])
379
+
380
+
381
def create_ui(container, button, tabname):
    """Build the extra-networks panel for one tab and wire up its events.

    container/button are the pre-existing gradio elements toggled by the user;
    tabname is e.g. "txt2img". Returns the populated ExtraNetworksUi holder.
    """
    ui = ExtraNetworksUi()
    ui.pages = []
    ui.pages_contents = []
    ui.user_metadata_editors = []
    ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
    ui.tabname = tabname

    with gr.Accordion("Extra Networks", open=True):

        with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
            # one tab per registered page, each holding its rendered card grid
            for page in ui.stored_extra_pages:
                page_id = page.title.lower().replace(" ", "_")
                with gr.Tab(page.title, id=page_id):
                    page_elem = gr.HTML(page.create_html(ui.tabname))
                    ui.pages.append(page_elem)

                    # re-apply the search filter whenever the page HTML is replaced
                    page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])

                    editor = page.create_user_metadata_editor(ui, tabname)
                    editor.create_ui()
                    ui.user_metadata_editors.append(editor)

            # hidden search box driven from javascript
            gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
            button_refresh = ToolButton(value=refresh_symbol, elem_id=tabname+"_extra_refresh")

        # hidden controls used by the legacy save-preview flow (see setup_ui)
        ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
        ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)

    def toggle_visibility(is_visible):
        # flip the panel open/closed and restyle the toggle button accordingly
        is_visible = not is_visible

        return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))

    def fill_tabs(is_empty):
        """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""

        if not ui.pages_contents:
            refresh()

        if is_empty:
            return True, *ui.pages_contents

        # already filled: leave existing page HTML untouched
        return True, *[gr.update() for _ in ui.pages_contents]

    state_visible = gr.State(value=False)
    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)

    state_empty = gr.State(value=True)
    button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)

    def refresh():
        # re-scan every page's backing model list, then rebuild all page HTML
        for pg in ui.stored_extra_pages:
            pg.refresh()

        ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]

        return ui.pages_contents

    button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)

    return ui
449
+
450
+
451
def path_is_parent(parent_path, child_path):
    """Return True when *child_path* starts with *parent_path* after both are made absolute.

    Note this is a plain string-prefix check on absolute paths (so a path equal
    to the parent also returns True), matching the original behavior.
    """
    return os.path.abspath(child_path).startswith(os.path.abspath(parent_path))
456
+
457
+
458
def setup_ui(ui, gallery):
    """Wire the hidden 'Save preview' button to write a gallery image as a card preview,
    and let each page's metadata editor attach its own handlers to *gallery*."""

    def save_preview(index, images, filename):
        # this function is here for backwards compatibility and likely will be removed soon

        if len(images) == 0:
            print("There is no image in gallery to save as a preview.")
            return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

        # clamp the selected gallery index into range
        index = int(index)
        index = 0 if index < 0 else index
        index = len(images) - 1 if index >= len(images) else index

        img_info = images[index if index >= 0 else 0]
        image = image_from_url_text(img_info)
        geninfo, items = read_info_from_image(image)

        # only allow writing inside directories the pages themselves declare
        is_allowed = False
        for extra_page in ui.stored_extra_pages:
            if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
                is_allowed = True
                break

        # fix: the message previously contained no placeholder ("(unknown)"); name the offending path
        assert is_allowed, f'writing to {filename} is not allowed'

        save_image_with_geninfo(image, geninfo, filename)

        return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

    ui.button_save_preview.click(
        fn=save_preview,
        _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
        outputs=[*ui.pages]
    )

    for editor in ui.user_metadata_editors:
        editor.setup_ui(gallery)
495
+
496
+
modules/ui_extra_networks_checkpoints.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+ import os
3
+
4
+ from modules import shared, ui_extra_networks, sd_models
5
+ from modules.ui_extra_networks import quote_js
6
+
7
+
8
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    """Extra-networks tab listing Stable Diffusion checkpoints."""

    def __init__(self):
        super().__init__('Checkpoints')

    def refresh(self):
        # re-scan checkpoint files on disk
        shared.refresh_checkpoints()

    def create_item(self, name, index=None):
        """Build the card dictionary for the checkpoint registered under *name*."""
        checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
        path, ext = os.path.splitext(checkpoint.filename)

        # include the sha256 so searching by hash finds the card
        search_term = self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or "")

        return {
            "name": checkpoint.name_for_extra,
            "filename": checkpoint.filename,
            "preview": self.find_preview(path),
            "description": self.find_description(path),
            "search_term": search_term,
            "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
            "local_preview": f"{path}.{shared.opts.samples_format}",
            "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
        }

    def list_items(self):
        """Yield one card dict per known checkpoint, in registry order."""
        for position, checkpoint_name in enumerate(sd_models.checkpoints_list):
            yield self.create_item(checkpoint_name, position)

    def allowed_directories_for_previews(self):
        return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
35
+
modules/ui_extra_networks_hypernets.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from modules import shared, ui_extra_networks
4
+ from modules.ui_extra_networks import quote_js
5
+
6
+
7
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
    """Extra-networks tab listing hypernetworks."""

    def __init__(self):
        super().__init__('Hypernetworks')

    def refresh(self):
        # re-scan hypernetwork files on disk
        shared.reload_hypernetworks()

    def create_item(self, name, index=None):
        """Build the card dictionary for the hypernetwork registered under *name*."""
        full_path = shared.hypernetworks[name]
        path, ext = os.path.splitext(full_path)

        # prompt text inserted on click; multiplier is resolved client-side from opts
        prompt = quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">")

        return {
            "name": name,
            "filename": full_path,
            "preview": self.find_preview(path),
            "description": self.find_description(path),
            "search_term": self.search_terms_from_path(path),
            "prompt": prompt,
            "local_preview": f"{path}.preview.{shared.opts.samples_format}",
            "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
        }

    def list_items(self):
        """Yield one card dict per loaded hypernetwork, in registry order."""
        for position, hypernet_name in enumerate(shared.hypernetworks):
            yield self.create_item(hypernet_name, position)

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.hypernetwork_dir]
35
+
modules/ui_extra_networks_textual_inversion.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from modules import ui_extra_networks, sd_hijack, shared
4
+ from modules.ui_extra_networks import quote_js
5
+
6
+
7
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
    """Extra-networks tab listing textual inversion embeddings."""

    def __init__(self):
        super().__init__('Textual Inversion')
        # embeddings may also be inserted into the negative prompt
        self.allow_negative_prompt = True

    def refresh(self):
        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)

    def create_item(self, name, index=None):
        """Build the card dictionary for the embedding registered under *name*."""
        embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)

        path, ext = os.path.splitext(embedding.filename)
        return {
            "name": name,
            "filename": embedding.filename,
            "preview": self.find_preview(path),
            "description": self.find_description(path),
            "search_term": self.search_terms_from_path(embedding.filename),
            "prompt": quote_js(embedding.name),
            "local_preview": f"{path}.preview.{shared.opts.samples_format}",
            "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
        }

    def list_items(self):
        """Yield one card dict per known embedding, in registry order."""
        for position, embedding_name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
            yield self.create_item(embedding_name, position)

    def allowed_directories_for_previews(self):
        return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
modules/ui_extra_networks_user_metadata.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import html
3
+ import json
4
+ import os.path
5
+
6
+ import gradio as gr
7
+
8
+ from modules import generation_parameters_copypaste, images, sysinfo, errors
9
+
10
+
11
class UserMetadataEditor:
    """Popup editor for a card's user metadata (description, notes, preview image).

    One instance exists per extra-networks page per tab; `page` supplies the
    items and knows where their sidecar .json files live.
    """

    def __init__(self, ui, tabname, page):
        self.ui = ui
        self.tabname = tabname
        self.page = page
        # element-id prefix shared by this editor's hidden controls
        self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"

        self.box = None

        self.edit_name_input = None
        self.button_edit = None

        self.edit_name = None
        self.edit_description = None
        self.edit_notes = None
        self.html_filedata = None
        self.html_preview = None
        self.html_status = None

        self.button_cancel = None
        self.button_replace_preview = None
        self.button_save = None

    def get_user_metadata(self, name):
        # Return (lazily creating) the mutable user-metadata dict stored on the item.
        item = self.page.items.get(name, {})

        user_metadata = item.get('user_metadata', None)
        if user_metadata is None:
            user_metadata = {}
            item['user_metadata'] = user_metadata

        return user_metadata

    def create_extra_default_items_in_left_column(self):
        # hook for subclasses to add page-specific fields under the description box
        pass

    def create_default_editor_elems(self):
        # two-column layout: text fields on the left, preview card on the right
        with gr.Row():
            with gr.Column(scale=2):
                self.edit_name = gr.HTML(elem_classes="extra-network-name")
                self.edit_description = gr.Textbox(label="Description", lines=4)
                self.html_filedata = gr.HTML()

                self.create_extra_default_items_in_left_column()

            with gr.Column(scale=1, min_width=0):
                self.html_preview = gr.HTML()

    def create_default_buttons(self):

        with gr.Row(elem_classes="edit-user-metadata-buttons"):
            self.button_cancel = gr.Button('Cancel')
            self.button_replace_preview = gr.Button('Replace preview', variant='primary')
            self.button_save = gr.Button('Save', variant='primary')

        self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")

        # Cancel closes the popup entirely on the javascript side
        self.button_cancel.click(fn=None, _js="closePopup")

    def get_card_html(self, name):
        # Return standalone preview-card HTML for the item, locating (and caching)
        # the preview image if the item does not have one yet.
        item = self.page.items.get(name, {})

        preview_url = item.get("preview", None)

        if not preview_url:
            filename, _ = os.path.splitext(item["filename"])
            preview_url = self.page.find_preview(filename)
            item["preview"] = preview_url

        if preview_url:
            preview = f'''
            <div class='card standalone-card-preview'>
                <img src="{html.escape(preview_url)}" class="preview">
            </div>
            '''
        else:
            preview = "<div class='card standalone-card-preview'></div>"

        return preview

    def get_metadata_table(self, name):
        # Return [(label, value), ...] rows describing the item's file on disk;
        # errors are reported and yield an empty table.
        item = self.page.items.get(name, {})
        try:
            filename = item["filename"]

            stats = os.stat(filename)
            params = [
                ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
                ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
            ]

            return params
        except Exception as e:
            errors.display(e, f"reading info for {name}")
            return []

    def put_values_into_components(self, name):
        # Fill the editor's widgets from the item's metadata; return values in the
        # same order as the outputs wired in create_editor().
        user_metadata = self.get_user_metadata(name)

        try:
            params = self.get_metadata_table(name)
        except Exception as e:
            errors.display(e, f"reading metadata info for {name}")
            params = []

        table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'

        return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')

    def write_user_metadata(self, name, metadata):
        # Persist *metadata* to the item's sidecar <basename>.json file.
        item = self.page.items.get(name, {})
        filename = item.get("filename", None)
        basename, ext = os.path.splitext(filename)

        with open(basename + '.json', "w", encoding="utf8") as file:
            json.dump(metadata, file)

    def save_user_metadata(self, name, desc, notes):
        user_metadata = self.get_user_metadata(name)
        user_metadata["description"] = desc
        user_metadata["notes"] = notes

        self.write_user_metadata(name, user_metadata)

    def setup_save_handler(self, button, func, components):
        # After a successful save, close the popup and refresh just this item's card.
        button\
            .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
            .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])

    def create_editor(self):
        self.create_default_editor_elems()

        self.edit_notes = gr.TextArea(label='Notes', lines=4)

        self.create_default_buttons()

        # "Edit" first loads the item's values, then makes the popup visible
        self.button_edit\
            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])

        self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])

    def create_ui(self):
        # Hidden box plus hidden name/button controls that javascript drives to open the editor.
        with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
            self.box = box

            self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
            self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")

            self.create_editor()

    def save_preview(self, index, gallery, name):
        # Save gallery image *index* as the item's local preview; returns the new
        # card HTML and a status message.
        if len(gallery) == 0:
            return self.get_card_html(name), "There is no image in gallery to save as a preview."

        item = self.page.items.get(name, {})

        # clamp the selected gallery index into range
        index = int(index)
        index = 0 if index < 0 else index
        index = len(gallery) - 1 if index >= len(gallery) else index

        img_info = gallery[index if index >= 0 else 0]
        image = generation_parameters_copypaste.image_from_url_text(img_info)
        geninfo, items = images.read_info_from_image(image)

        images.save_image_with_geninfo(image, geninfo, item["local_preview"])

        return self.get_card_html(name), ''

    def setup_ui(self, gallery):
        # After replacing the preview, refresh this item's card in the page as well.
        self.button_replace_preview.click(
            fn=self.save_preview,
            _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
            inputs=[self.edit_name_input, gallery, self.edit_name_input],
            outputs=[self.html_preview, self.html_status]
        ).then(
            fn=None,
            _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
            inputs=[self.edit_name_input],
            outputs=[]
        )
193
+
194
+
195
+
modules/ui_gradio_extensions.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+
4
+ from modules import localization, shared, scripts
5
+ from modules.paths import script_path, data_path
6
+
7
+
8
def webpath(fn):
    """Turn a filesystem path into a `file=`-style URL; the mtime query busts caches."""
    if fn.startswith(script_path):
        rel = os.path.relpath(fn, script_path).replace('\\', '/')
    else:
        rel = os.path.abspath(fn)

    return f'file={rel}?{os.path.getmtime(fn)}'
15
+
16
+
17
def javascript_html():
    """Assemble all <script> tags to be injected into the gradio page's <head>."""
    # Ensure localization is in `window` before scripts
    parts = [f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n']

    script_js = os.path.join(script_path, "script.js")
    parts.append(f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n')

    for script in scripts.list_scripts("javascript", ".js"):
        parts.append(f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n')

    for script in scripts.list_scripts("javascript", ".mjs"):
        parts.append(f'<script type="module" src="{webpath(script.path)}"></script>\n')

    if shared.cmd_opts.theme:
        parts.append(f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n')

    return "".join(parts)
34
+
35
+
36
def css_html():
    """Assemble <link rel="stylesheet"> tags for every extension style.css plus user.css."""
    def stylesheet(fn):
        return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'

    links = [stylesheet(cssfile) for cssfile in scripts.list_files_with_name("style.css") if os.path.isfile(cssfile)]

    user_css = os.path.join(data_path, "user.css")
    if os.path.exists(user_css):
        links.append(stylesheet(user_css))

    return "".join(links)
52
+
53
+
54
def reload_javascript():
    """Monkeypatch gradio's TemplateResponse so our JS and CSS are injected into every served page."""
    js = javascript_html()
    css = css_html()

    def template_response(*args, **kwargs):
        # delegate to the saved original response class, then splice our tags in
        # just before the closing head/body tags
        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
        res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
        res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response


# save the pristine TemplateResponse exactly once, so repeated reloads
# don't wrap the wrapper
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
    shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
modules/ui_loadsave.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import gradio as gr
5
+
6
+ from modules import errors
7
+ from modules.ui_components import ToolButton
8
+
9
+
10
class UiLoadsave:
    """Allows saving and restoring default values for gradio components."""

    def __init__(self, filename):
        self.filename = filename
        self.ui_settings = {}        # key -> saved default value
        self.component_mapping = {}  # key -> gradio component (tracked 'value' fields)
        self.error_loading = False   # True if the settings file failed to load
        self.finalized_ui = False    # set by setup_ui(); no components may be added after

        self.ui_defaults_view = None
        self.ui_defaults_apply = None
        self.ui_defaults_review = None

        try:
            if os.path.exists(self.filename):
                self.ui_settings = self.read_from_file()
        except Exception as e:
            self.error_loading = True
            errors.display(e, "loading settings")

    def add_component(self, path, x):
        """adds component to the registry of tracked components"""

        assert not self.finalized_ui

        def apply_field(obj, field, condition=None, init_field=None):
            key = f"{path}/{field}"

            if getattr(obj, 'custom_script_source', None) is not None:
                key = f"customscript/{obj.custom_script_source}/{key}"

            if getattr(obj, 'do_not_save_to_config', False):
                return

            saved_value = self.ui_settings.get(key, None)
            if saved_value is None:
                # nothing saved yet: record the component's current value as the default
                self.ui_settings[key] = getattr(obj, field)
            elif condition and not condition(saved_value):
                # saved value is no longer valid for this component (e.g. removed choice)
                pass
            else:
                setattr(obj, field, saved_value)
                if init_field is not None:
                    init_field(saved_value)

            if field == 'value' and key not in self.component_mapping:
                self.component_mapping[key] = x

        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
            apply_field(x, 'visible')

        if type(x) == gr.Slider:
            apply_field(x, 'value')
            apply_field(x, 'minimum')
            apply_field(x, 'maximum')
            apply_field(x, 'step')

        if type(x) == gr.Radio:
            apply_field(x, 'value', lambda val: val in x.choices)

        if type(x) == gr.Checkbox:
            apply_field(x, 'value')

        if type(x) == gr.Textbox:
            apply_field(x, 'value')

        if type(x) == gr.Number:
            apply_field(x, 'value')

        if type(x) == gr.Dropdown:
            def check_dropdown(val):
                if getattr(x, 'multiselect', False):
                    return all(value in x.choices for value in val)
                else:
                    return val in x.choices

            apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))

        def check_tab_id(tab_id):
            tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
            if type(tab_id) == str:
                tab_ids = [t.id for t in tab_items]
                return tab_id in tab_ids
            elif type(tab_id) == int:
                return 0 <= tab_id < len(tab_items)
            else:
                return False

        if type(x) == gr.Tabs:
            apply_field(x, 'selected', check_tab_id)

    def add_block(self, x, path=""):
        """adds all components inside a gradio block x to the registry of tracked components"""

        if hasattr(x, 'children'):
            if isinstance(x, gr.Tabs) and x.elem_id is not None:
                # Tabs element can't have a label, have to use elem_id instead
                self.add_component(f"{path}/Tabs@{x.elem_id}", x)
            for c in x.children:
                self.add_block(c, path)
        elif x.label is not None:
            self.add_component(f"{path}/{x.label}", x)
        elif isinstance(x, gr.Button) and x.value is not None:
            self.add_component(f"{path}/{x.value}", x)

    def read_from_file(self):
        with open(self.filename, "r", encoding="utf8") as file:
            return json.load(file)

    def write_to_file(self, current_ui_settings):
        with open(self.filename, "w", encoding="utf8") as file:
            json.dump(current_ui_settings, file, indent=4)

    def dump_defaults(self):
        """saves default values to a file unless the file is present and there was an error loading default values at start"""

        if self.error_loading and os.path.exists(self.filename):
            return

        self.write_to_file(self.ui_settings)

    def iter_changes(self, current_ui_settings, values):
        """
        given a dictionary with defaults from a file and current values from gradio elements, returns
        an iterator over tuples of values that are not the same between the file and the current;
        tuple contents are: path, old value, new value
        """

        for (path, component), new_value in zip(self.component_mapping.items(), values):
            old_value = current_ui_settings.get(path)

            choices = getattr(component, 'choices', None)
            if isinstance(new_value, int) and choices:
                # the component may report the index of the selection; translate to the choice
                if new_value >= len(choices):
                    continue

                new_value = choices[new_value]

            if new_value == old_value:
                continue

            # fix: parenthesized so an empty value is only ignored when there is no old value;
            # the previous `a and b or c` precedence silently dropped every empty-list change
            if old_value is None and (new_value == '' or new_value == []):
                continue

            yield path, old_value, new_value

    def ui_view(self, *values):
        text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]

        for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
            if old_value is None:
                old_value = "<span class='ui-defaults-none'>None</span>"

            text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")

        if len(text) == 1:
            text.append("<tr><td colspan=3>No changes</td></tr>")

        # fix: close the table element (previously only </tbody> was emitted)
        text.append("</tbody></table>")
        return "".join(text)

    def ui_apply(self, *values):
        num_changed = 0

        current_ui_settings = self.read_from_file()

        for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
            num_changed += 1
            current_ui_settings[path] = new_value

        if num_changed == 0:
            return "No changes."

        self.write_to_file(current_ui_settings)

        return f"Wrote {num_changed} changes."

    def create_ui(self):
        """creates ui elements for editing defaults UI, without adding any logic to them"""

        gr.HTML(
            f"This page allows you to change default values in UI elements on other tabs.<br />"
            f"Make your changes, press 'View changes' to review the changed default values,<br />"
            f"then press 'Apply' to write them to {self.filename}.<br />"
            f"New defaults will apply after you restart the UI.<br />"
        )

        with gr.Row():
            self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
            self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")

        self.ui_defaults_review = gr.HTML("")

    def setup_ui(self):
        """adds logic to elements created with create_ui; all add_block calls must be made before this"""

        assert not self.finalized_ui
        self.finalized_ui = True

        self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
        self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
modules/ui_postprocessing.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from modules import scripts, shared, ui_common, postprocessing, call_queue
3
+ import modules.generation_parameters_copypaste as parameters_copypaste
4
+
5
+ def create_ui():
6
+ tab_index = gr.State(value=0)
7
+ gr.Row(elem_id="extras_2img_prompt_image", visible=False)
8
+ with gr.Row():
9
+ with gr.Column(elem_id="extras_2img_results"):
10
+ result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras_2img", shared.opts.outdir_extras_samples)
11
+ gr.Row(elem_id="extras_2img_splitter")
12
+ with gr.Column(variant='panel', elem_id="extras_2img_settings"):
13
+ submit = gr.Button('Upscale', elem_id="extras_generate", variant='primary')
14
+ with gr.Column(elem_id="extras_2img_settings_scroll"):
15
+ with gr.Accordion("Image Source", elem_id="extras_accordion", open=True):
16
+ with gr.Tabs(elem_id="mode_extras"):
17
+ with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
18
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
19
+
20
+ with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
21
+ image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
22
+
23
+ with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
24
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
25
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
26
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
27
+
28
+ script_inputs = scripts.scripts_postproc.setup_ui()
29
+
30
+ tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
31
+ tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
32
+ tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
33
+
34
+ submit.click(
35
+ fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
36
+ inputs=[
37
+ tab_index,
38
+ extras_image,
39
+ image_batch,
40
+ extras_batch_input_dir,
41
+ extras_batch_output_dir,
42
+ show_extras_results,
43
+ *script_inputs
44
+ ],
45
+ outputs=[
46
+ result_images,
47
+ html_info_x,
48
+ html_info,
49
+ ]
50
+ )
51
+
52
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
53
+
54
+ extras_image.change(
55
+ fn=scripts.scripts_postproc.image_changed,
56
+ inputs=[], outputs=[]
57
+ )
modules/ui_settings.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
4
+ from modules.call_queue import wrap_gradio_call
5
+ from modules.shared import opts
6
+ from modules.ui_components import FormRow
7
+ from modules.ui_gradio_extensions import reload_javascript
8
+
9
+
10
+ def get_value_for_setting(key):
11
+ value = getattr(opts, key)
12
+
13
+ info = opts.data_labels[key]
14
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
15
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
16
+
17
+ return gr.update(value=value, **args)
18
+
19
+
20
+ def create_setting_component(key, is_quicksettings=False):
21
+ def fun():
22
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
23
+
24
+ info = opts.data_labels[key]
25
+ t = type(info.default)
26
+
27
+ args = info.component_args() if callable(info.component_args) else info.component_args
28
+
29
+ if info.component is not None:
30
+ comp = info.component
31
+ elif t == str:
32
+ comp = gr.Textbox
33
+ elif t == int:
34
+ comp = gr.Number
35
+ elif t == bool:
36
+ comp = gr.Checkbox
37
+ else:
38
+ raise Exception(f'bad options item type: {t} for key {key}')
39
+
40
+ elem_id = f"setting_{key}"
41
+
42
+ if info.refresh is not None:
43
+ if is_quicksettings:
44
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
45
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
46
+ else:
47
+ with FormRow():
48
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
49
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
50
+ else:
51
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
52
+
53
+ return res
54
+
55
+
56
+ class UiSettings:
57
+ submit = None
58
+ result = None
59
+ interface = None
60
+ components = None
61
+ component_dict = None
62
+ dummy_component = None
63
+ quicksettings_list = None
64
+ quicksettings_names = None
65
+ text_settings = None
66
+
67
+ def run_settings(self, *args):
68
+ changed = []
69
+
70
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
71
+ assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
72
+
73
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
74
+ if comp == self.dummy_component:
75
+ continue
76
+
77
+ if opts.set(key, value):
78
+ changed.append(key)
79
+
80
+ try:
81
+ opts.save(shared.config_filename)
82
+ except RuntimeError:
83
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
84
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.'
85
+
86
+ def run_settings_single(self, value, key):
87
+ if not opts.same_type(value, opts.data_labels[key].default):
88
+ return gr.update(visible=True), opts.dumpjson()
89
+
90
+ if not opts.set(key, value):
91
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
92
+
93
+ opts.save(shared.config_filename)
94
+
95
+ return get_value_for_setting(key), opts.dumpjson()
96
+
97
+ def create_ui(self, loadsave, dummy_component):
98
+ self.components = []
99
+ self.component_dict = {}
100
+ self.dummy_component = dummy_component
101
+
102
+ shared.settings_components = self.component_dict
103
+
104
+ script_callbacks.ui_settings_callback()
105
+ opts.reorder()
106
+
107
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
108
+ with gr.Row():
109
+ with gr.Column(scale=6):
110
+ self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
111
+ with gr.Column():
112
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
113
+
114
+ self.result = gr.HTML(elem_id="settings_result")
115
+
116
+ self.quicksettings_names = opts.quicksettings_list
117
+ self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
118
+
119
+ self.quicksettings_list = []
120
+
121
+ previous_section = None
122
+ current_tab = None
123
+ current_row = None
124
+ with gr.Tabs(elem_id="settings"):
125
+ for i, (k, item) in enumerate(opts.data_labels.items()):
126
+ section_must_be_skipped = item.section[0] is None
127
+
128
+ if previous_section != item.section and not section_must_be_skipped:
129
+ elem_id, text = item.section
130
+
131
+ if current_tab is not None:
132
+ current_row.__exit__()
133
+ current_tab.__exit__()
134
+
135
+ gr.Group()
136
+ current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
137
+ current_tab.__enter__()
138
+ current_row = gr.Column(variant='compact')
139
+ current_row.__enter__()
140
+
141
+ previous_section = item.section
142
+
143
+ if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
144
+ self.quicksettings_list.append((i, k, item))
145
+ self.components.append(dummy_component)
146
+ elif section_must_be_skipped:
147
+ self.components.append(dummy_component)
148
+ else:
149
+ component = create_setting_component(k)
150
+ self.component_dict[k] = component
151
+ self.components.append(component)
152
+
153
+ if current_tab is not None:
154
+ current_row.__exit__()
155
+ current_tab.__exit__()
156
+
157
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
158
+ loadsave.create_ui()
159
+
160
+ with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
161
+ gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
162
+
163
+ with gr.Row():
164
+ with gr.Column(scale=1):
165
+ sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
166
+ with gr.Column(scale=1):
167
+ sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
168
+ with gr.Column(scale=100):
169
+ pass
170
+
171
+ with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
172
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
173
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
174
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
175
+ with gr.Row():
176
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
177
+ reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
178
+
179
+ with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
180
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
181
+
182
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
183
+
184
+ self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
185
+
186
+ unload_sd_model.click(
187
+ fn=sd_models.unload_model_weights,
188
+ inputs=[],
189
+ outputs=[]
190
+ )
191
+
192
+ reload_sd_model.click(
193
+ fn=sd_models.reload_model_weights,
194
+ inputs=[],
195
+ outputs=[]
196
+ )
197
+
198
+ request_notifications.click(
199
+ fn=lambda: None,
200
+ inputs=[],
201
+ outputs=[],
202
+ _js='function(){}'
203
+ )
204
+
205
+ download_localization.click(
206
+ fn=lambda: None,
207
+ inputs=[],
208
+ outputs=[],
209
+ _js='download_localization'
210
+ )
211
+
212
+ def reload_scripts():
213
+ scripts.reload_script_body_only()
214
+ reload_javascript() # need to refresh the html page
215
+
216
+ reload_script_bodies.click(
217
+ fn=reload_scripts,
218
+ inputs=[],
219
+ outputs=[]
220
+ )
221
+
222
+ restart_gradio.click(
223
+ fn=shared.state.request_restart,
224
+ _js='restart_reload',
225
+ inputs=[],
226
+ outputs=[],
227
+ )
228
+
229
+ def check_file(x):
230
+ if x is None:
231
+ return ''
232
+
233
+ if sysinfo.check(x.decode('utf8', errors='ignore')):
234
+ return 'Valid'
235
+
236
+ return 'Invalid'
237
+
238
+ sysinfo_check_file.change(
239
+ fn=check_file,
240
+ inputs=[sysinfo_check_file],
241
+ outputs=[sysinfo_check_output],
242
+ )
243
+
244
+ self.interface = settings_interface
245
+
246
+ def add_quicksettings(self):
247
+ with gr.Row(elem_id="quicksettings", variant="compact"):
248
+ for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
249
+ component = create_setting_component(k, is_quicksettings=True)
250
+ self.component_dict[k] = component
251
+
252
+ def add_functionality(self, demo):
253
+ self.submit.click(
254
+ fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
255
+ inputs=self.components,
256
+ outputs=[self.text_settings, self.result],
257
+ )
258
+
259
+ for _i, k, _item in self.quicksettings_list:
260
+ component = self.component_dict[k]
261
+ info = opts.data_labels[k]
262
+
263
+ if isinstance(component, gr.Textbox):
264
+ methods = [component.submit, component.blur]
265
+ elif hasattr(component, 'release'):
266
+ methods = [component.release]
267
+ else:
268
+ methods = [component.change]
269
+
270
+ for method in methods:
271
+ method(
272
+ fn=lambda value, k=k: self.run_settings_single(value, key=k),
273
+ inputs=[component],
274
+ outputs=[component, self.text_settings],
275
+ show_progress=info.refresh is not None,
276
+ )
277
+
278
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
279
+ button_set_checkpoint.click(
280
+ fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
281
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
282
+ inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
283
+ outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
284
+ )
285
+
286
+ component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
287
+
288
+ def get_settings_values():
289
+ return [get_value_for_setting(key) for key in component_keys]
290
+
291
+ demo.load(
292
+ fn=get_settings_values,
293
+ inputs=[],
294
+ outputs=[self.component_dict[k] for k in component_keys],
295
+ queue=False,
296
+ )
modules/ui_tempdir.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from collections import namedtuple
4
+ from pathlib import Path
5
+
6
+ import gradio.components
7
+
8
+ from PIL import PngImagePlugin
9
+
10
+ from modules import shared
11
+
12
+
13
# Minimal stand-in for a saved-file record; only the .name attribute is used.
Savedfile = namedtuple("Savedfile", ["name"])
14
+
15
+
16
+ def register_tmp_file(gradio, filename):
17
+ if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
18
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
19
+
20
+ if hasattr(gradio, 'temp_dirs'): # gradio 3.9
21
+ gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
22
+
23
+
24
+ def check_tmp_file(gradio, filename):
25
+ if hasattr(gradio, 'temp_file_sets'):
26
+ return any(filename in fileset for fileset in gradio.temp_file_sets)
27
+
28
+ if hasattr(gradio, 'temp_dirs'):
29
+ return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
30
+
31
+ return False
32
+
33
+
34
+ def save_pil_to_file(self, pil_image, dir=None, format="png"):
35
+ already_saved_as = getattr(pil_image, 'already_saved_as', None)
36
+ if already_saved_as and os.path.isfile(already_saved_as):
37
+ register_tmp_file(shared.demo, already_saved_as)
38
+ filename = already_saved_as
39
+
40
+ if not shared.opts.save_images_add_number:
41
+ filename += f'?{os.path.getmtime(already_saved_as)}'
42
+
43
+ return filename
44
+
45
+ if shared.opts.temp_dir != "":
46
+ dir = shared.opts.temp_dir
47
+
48
+ use_metadata = False
49
+ metadata = PngImagePlugin.PngInfo()
50
+ for key, value in pil_image.info.items():
51
+ if isinstance(key, str) and isinstance(value, str):
52
+ metadata.add_text(key, value)
53
+ use_metadata = True
54
+
55
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
56
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
57
+ return file_obj.name
58
+
59
+
60
+ # override save to file function so that it also writes PNG info
61
+ gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
62
+
63
+
64
+ def on_tmpdir_changed():
65
+ if shared.opts.temp_dir == "" or shared.demo is None:
66
+ return
67
+
68
+ os.makedirs(shared.opts.temp_dir, exist_ok=True)
69
+
70
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
71
+
72
+
73
+ def cleanup_tmpdr():
74
+ temp_dir = shared.opts.temp_dir
75
+ if temp_dir == "" or not os.path.isdir(temp_dir):
76
+ return
77
+
78
+ for root, _, files in os.walk(temp_dir, topdown=False):
79
+ for name in files:
80
+ _, extension = os.path.splitext(name)
81
+ if extension != ".png":
82
+ continue
83
+
84
+ filename = os.path.join(root, name)
85
+ os.remove(filename)
modules/upscaler.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import abstractmethod
3
+
4
+ import PIL
5
+ from PIL import Image
6
+
7
+ import modules.shared
8
+ from modules import modelloader, shared
9
+
10
# Pillow >= 9.1 moved resampling filters under Image.Resampling; fall back for older versions
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
12
+
13
+
14
+ class Upscaler:
15
+ name = None
16
+ model_path = None
17
+ model_name = None
18
+ model_url = None
19
+ enable = True
20
+ filter = None
21
+ model = None
22
+ user_path = None
23
+ scalers: []
24
+ tile = True
25
+
26
+ def __init__(self, create_dirs=False):
27
+ self.mod_pad_h = None
28
+ self.tile_size = modules.shared.opts.ESRGAN_tile
29
+ self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
30
+ self.device = modules.shared.device
31
+ self.img = None
32
+ self.output = None
33
+ self.scale = 1
34
+ self.half = not modules.shared.cmd_opts.no_half
35
+ self.pre_pad = 0
36
+ self.mod_scale = None
37
+ self.model_download_path = None
38
+
39
+ if self.model_path is None and self.name:
40
+ self.model_path = os.path.join(shared.models_path, self.name)
41
+ if self.model_path and create_dirs:
42
+ os.makedirs(self.model_path, exist_ok=True)
43
+
44
+ try:
45
+ import cv2 # noqa: F401
46
+ self.can_tile = True
47
+ except Exception:
48
+ pass
49
+
50
+ @abstractmethod
51
+ def do_upscale(self, img: PIL.Image, selected_model: str):
52
+ return img
53
+
54
+ def upscale(self, img: PIL.Image, scale, selected_model: str = None):
55
+ self.scale = scale
56
+ dest_w = int((img.width * scale) // 8 * 8)
57
+ dest_h = int((img.height * scale) // 8 * 8)
58
+
59
+ for _ in range(3):
60
+ shape = (img.width, img.height)
61
+
62
+ img = self.do_upscale(img, selected_model)
63
+
64
+ if shape == (img.width, img.height):
65
+ break
66
+
67
+ if img.width >= dest_w and img.height >= dest_h:
68
+ break
69
+
70
+ if img.width != dest_w or img.height != dest_h:
71
+ img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
72
+
73
+ return img
74
+
75
+ @abstractmethod
76
+ def load_model(self, path: str):
77
+ pass
78
+
79
+ def find_models(self, ext_filter=None) -> list:
80
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
81
+
82
+ def update_status(self, prompt):
83
+ print(f"\nextras: {prompt}", file=shared.progress_print_out)
84
+
85
+
86
+ class UpscalerData:
87
+ name = None
88
+ data_path = None
89
+ scale: int = 4
90
+ scaler: Upscaler = None
91
+ model: None
92
+
93
+ def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
94
+ self.name = name
95
+ self.data_path = path
96
+ self.local_data_path = path
97
+ self.scaler = upscaler
98
+ self.scale = scale
99
+ self.model = model
100
+
101
+
102
+ class UpscalerNone(Upscaler):
103
+ name = "None"
104
+ scalers = []
105
+
106
+ def load_model(self, path):
107
+ pass
108
+
109
+ def do_upscale(self, img, selected_model=None):
110
+ return img
111
+
112
+ def __init__(self, dirname=None):
113
+ super().__init__(False)
114
+ self.scalers = [UpscalerData("None", None, self)]
115
+
116
+
117
+ class UpscalerLanczos(Upscaler):
118
+ scalers = []
119
+
120
+ def do_upscale(self, img, selected_model=None):
121
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
122
+
123
+ def load_model(self, _):
124
+ pass
125
+
126
+ def __init__(self, dirname=None):
127
+ super().__init__(False)
128
+ self.name = "Lanczos"
129
+ self.scalers = [UpscalerData("Lanczos", None, self)]
130
+
131
+
132
+ class UpscalerNearest(Upscaler):
133
+ scalers = []
134
+
135
+ def do_upscale(self, img, selected_model=None):
136
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
137
+
138
+ def load_model(self, _):
139
+ pass
140
+
141
+ def __init__(self, dirname=None):
142
+ super().__init__(False)
143
+ self.name = "Nearest"
144
+ self.scalers = [UpscalerData("Nearest", None, self)]
modules/xlmr.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertPreTrainedModel, BertConfig
2
+ import torch.nn as nn
3
+ import torch
4
+ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
5
+ from transformers import XLMRobertaModel,XLMRobertaTokenizer
6
+ from typing import Optional
7
+
8
+ class BertSeriesConfig(BertConfig):
9
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
10
+
11
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
12
+ self.project_dim = project_dim
13
+ self.pooler_fn = pooler_fn
14
+ self.learn_encoder = learn_encoder
15
+
16
+ class RobertaSeriesConfig(XLMRobertaConfig):
17
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
18
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
19
+ self.project_dim = project_dim
20
+ self.pooler_fn = pooler_fn
21
+ self.learn_encoder = learn_encoder
22
+
23
+
24
+ class BertSeriesModelWithTransformation(BertPreTrainedModel):
25
+
26
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
27
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
28
+ config_class = BertSeriesConfig
29
+
30
+ def __init__(self, config=None, **kargs):
31
+ # modify initialization for autoloading
32
+ if config is None:
33
+ config = XLMRobertaConfig()
34
+ config.attention_probs_dropout_prob= 0.1
35
+ config.bos_token_id=0
36
+ config.eos_token_id=2
37
+ config.hidden_act='gelu'
38
+ config.hidden_dropout_prob=0.1
39
+ config.hidden_size=1024
40
+ config.initializer_range=0.02
41
+ config.intermediate_size=4096
42
+ config.layer_norm_eps=1e-05
43
+ config.max_position_embeddings=514
44
+
45
+ config.num_attention_heads=16
46
+ config.num_hidden_layers=24
47
+ config.output_past=True
48
+ config.pad_token_id=1
49
+ config.position_embedding_type= "absolute"
50
+
51
+ config.type_vocab_size= 1
52
+ config.use_cache=True
53
+ config.vocab_size= 250002
54
+ config.project_dim = 768
55
+ config.learn_encoder = False
56
+ super().__init__(config)
57
+ self.roberta = XLMRobertaModel(config)
58
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
59
+ self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
60
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
61
+ self.pooler = lambda x: x[:,0]
62
+ self.post_init()
63
+
64
+ def encode(self,c):
65
+ device = next(self.parameters()).device
66
+ text = self.tokenizer(c,
67
+ truncation=True,
68
+ max_length=77,
69
+ return_length=False,
70
+ return_overflowing_tokens=False,
71
+ padding="max_length",
72
+ return_tensors="pt")
73
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
74
+ text["attention_mask"] = torch.tensor(
75
+ text['attention_mask']).to(device)
76
+ features = self(**text)
77
+ return features['projection_state']
78
+
79
+ def forward(
80
+ self,
81
+ input_ids: Optional[torch.Tensor] = None,
82
+ attention_mask: Optional[torch.Tensor] = None,
83
+ token_type_ids: Optional[torch.Tensor] = None,
84
+ position_ids: Optional[torch.Tensor] = None,
85
+ head_mask: Optional[torch.Tensor] = None,
86
+ inputs_embeds: Optional[torch.Tensor] = None,
87
+ encoder_hidden_states: Optional[torch.Tensor] = None,
88
+ encoder_attention_mask: Optional[torch.Tensor] = None,
89
+ output_attentions: Optional[bool] = None,
90
+ return_dict: Optional[bool] = None,
91
+ output_hidden_states: Optional[bool] = None,
92
+ ) :
93
+ r"""
94
+ """
95
+
96
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
97
+
98
+
99
+ outputs = self.roberta(
100
+ input_ids=input_ids,
101
+ attention_mask=attention_mask,
102
+ token_type_ids=token_type_ids,
103
+ position_ids=position_ids,
104
+ head_mask=head_mask,
105
+ inputs_embeds=inputs_embeds,
106
+ encoder_hidden_states=encoder_hidden_states,
107
+ encoder_attention_mask=encoder_attention_mask,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=True,
110
+ return_dict=return_dict,
111
+ )
112
+
113
+ # last module outputs
114
+ sequence_output = outputs[0]
115
+
116
+
117
+ # project every module
118
+ sequence_output_ln = self.pre_LN(sequence_output)
119
+
120
+ # pooler
121
+ pooler_output = self.pooler(sequence_output_ln)
122
+ pooler_output = self.transformation(pooler_output)
123
+ projection_state = self.transformation(outputs.last_hidden_state)
124
+
125
+ return {
126
+ 'pooler_output':pooler_output,
127
+ 'last_hidden_state':outputs.last_hidden_state,
128
+ 'hidden_states':outputs.hidden_states,
129
+ 'attentions':outputs.attentions,
130
+ 'projection_state':projection_state,
131
+ 'sequence_out': sequence_output
132
+ }
133
+
134
+
135
+ class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
136
+ base_model_prefix = 'roberta'
137
+ config_class= RobertaSeriesConfig
outputs/txt2img-images/2023-07-30/00000-4104476258.png ADDED
outputs/txt2img-images/2023-07-30/00001-1264812310.png ADDED
outputs/txt2img-images/2023-07-30/00002-629074369.png ADDED
outputs/txt2img-images/2023-07-30/00003-3929529382.png ADDED
outputs/txt2img-images/2023-07-30/00004-2891905160.png ADDED
outputs/txt2img-images/2023-07-30/00005-1703927525.png ADDED
outputs/txt2img-images/2023-07-30/00006-1703927525.png ADDED
outputs/txt2img-images/2023-07-30/00007-1703927525.png ADDED
outputs/txt2img-images/2023-07-30/00008-1703927525.png ADDED
outputs/txt2img-images/2023-07-30/00009-1703927525.jpg ADDED
outputs/txt2img-images/2023-07-30/00010-210755578.jpg ADDED
outputs/txt2img-images/2023-07-30/00011-3978311133.jpg ADDED
outputs/txt2img-images/2023-07-30/00012-3786155085.jpg ADDED
outputs/txt2img-images/2023-07-30/00013-445379948.jpg ADDED
outputs/txt2img-images/2023-07-30/00014-3277595636.png ADDED
package.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "stable-diffusion-webui",
3
+ "version": "0.0.0",
4
+ "devDependencies": {
5
+ "eslint": "^8.40.0"
6
+ },
7
+ "scripts": {
8
+ "lint": "eslint .",
9
+ "fix": "eslint --fix ."
10
+ }
11
+ }
params.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ masterpiece, best quality,
2
+ Negative prompt: (worst quality, low quality:1.4)
3
+ Steps: 20, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 3277595636, Size: 512x768, Model hash: ee5e7d0285, Model: SCH_Excelsior, Version: 1.5.0
pyproject.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.ruff]
2
+
3
+ target-version = "py39"
4
+
5
+ extend-select = [
6
+ "B",
7
+ "C",
8
+ "I",
9
+ "W",
10
+ ]
11
+
12
+ exclude = [
13
+ "extensions",
14
+ "extensions-disabled",
15
+ ]
16
+
17
+ ignore = [
18
+ "E501", # Line too long
19
+ "E731", # Do not assign a `lambda` expression, use a `def`
20
+
21
+ "I001", # Import block is un-sorted or un-formatted
22
+ "C901", # Function is too complex
23
+ "C408", # Rewrite as a literal
24
+ "W605", # invalid escape sequence, messes with some docstrings
25
+ ]
26
+
27
+ [tool.ruff.per-file-ignores]
28
+ "webui.py" = ["E402"] # Module level import not at top of file
29
+
30
+ [tool.ruff.flake8-bugbear]
31
+ # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
32
+ extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
33
+
34
+ [tool.pytest.ini_options]
35
+ base_url = "http://127.0.0.1:7860"
repositories/BLIP/BLIP.gif ADDED

Git LFS Details

  • SHA256: 7757a1a1133807158ec4e696a8187f289e64c30a86aa470d8e0a93948a02be22
  • Pointer size: 132 Bytes
  • Size of remote file: 6.71 MB
repositories/BLIP/CODEOWNERS ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2
+ #ECCN:Open Source
repositories/BLIP/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Salesforce Open Source Community Code of Conduct
2
+
3
+ ## About the Code of Conduct
4
+
5
+ Equality is a core value at Salesforce. We believe a diverse and inclusive
6
+ community fosters innovation and creativity, and are committed to building a
7
+ culture where everyone feels included.
8
+
9
+ Salesforce open-source projects are committed to providing a friendly, safe, and
10
+ welcoming environment for all, regardless of gender identity and expression,
11
+ sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12
+ race, age, religion, level of experience, education, socioeconomic status, or
13
+ other similar personal characteristics.
14
+
15
+ The goal of this code of conduct is to specify a baseline standard of behavior so
16
+ that people with different social values and communication styles can work
17
+ together effectively, productively, and respectfully in our open source community.
18
+ It also establishes a mechanism for reporting issues and resolving conflicts.
19
+
20
+ All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21
+ in a Salesforce open-source project may be reported by contacting the Salesforce
22
+ Open Source Conduct Committee at ossconduct@salesforce.com.
23
+
24
+ ## Our Pledge
25
+
26
+ In the interest of fostering an open and welcoming environment, we as
27
+ contributors and maintainers pledge to making participation in our project and
28
+ our community a harassment-free experience for everyone, regardless of gender
29
+ identity and expression, sexual orientation, disability, physical appearance,
30
+ body size, ethnicity, nationality, race, age, religion, level of experience, education,
31
+ socioeconomic status, or other similar personal characteristics.
32
+
33
+ ## Our Standards
34
+
35
+ Examples of behavior that contributes to creating a positive environment
36
+ include:
37
+
38
+ * Using welcoming and inclusive language
39
+ * Being respectful of differing viewpoints and experiences
40
+ * Gracefully accepting constructive criticism
41
+ * Focusing on what is best for the community
42
+ * Showing empathy toward other community members
43
+
44
+ Examples of unacceptable behavior by participants include:
45
+
46
+ * The use of sexualized language or imagery and unwelcome sexual attention or
47
+ advances
48
+ * Personal attacks, insulting/derogatory comments, or trolling
49
+ * Public or private harassment
50
+ * Publishing, or threatening to publish, others' private information—such as
51
+ a physical or electronic address—without explicit permission
52
+ * Other conduct which could reasonably be considered inappropriate in a
53
+ professional setting
54
+ * Advocating for or encouraging any of the above behaviors
55
+
56
+ ## Our Responsibilities
57
+
58
+ Project maintainers are responsible for clarifying the standards of acceptable
59
+ behavior and are expected to take appropriate and fair corrective action in
60
+ response to any instances of unacceptable behavior.
61
+
62
+ Project maintainers have the right and responsibility to remove, edit, or
63
+ reject comments, commits, code, wiki edits, issues, and other contributions
64
+ that are not aligned with this Code of Conduct, or to ban temporarily or
65
+ permanently any contributor for other behaviors that they deem inappropriate,
66
+ threatening, offensive, or harmful.
67
+
68
+ ## Scope
69
+
70
+ This Code of Conduct applies both within project spaces and in public spaces
71
+ when an individual is representing the project or its community. Examples of
72
+ representing a project or community include using an official project email
73
+ address, posting via an official social media account, or acting as an appointed
74
+ representative at an online or offline event. Representation of a project may be
75
+ further defined and clarified by project maintainers.
76
+
77
+ ## Enforcement
78
+
79
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
80
+ reported by contacting the Salesforce Open Source Conduct Committee
81
+ at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82
+ and will result in a response that is deemed necessary and appropriate to the
83
+ circumstances. The committee is obligated to maintain confidentiality with
84
+ regard to the reporter of an incident. Further details of specific enforcement
85
+ policies may be posted separately.
86
+
87
+ Project maintainers who do not follow or enforce the Code of Conduct in good
88
+ faith may face temporary or permanent repercussions as determined by other
89
+ members of the project's leadership and the Salesforce Open Source Conduct
90
+ Committee.
91
+
92
+ ## Attribution
93
+
94
+ This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95
+ version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96
+ It includes adaptations and additions from [Go Community Code of Conduct][golang-coc],
97
+ [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98
+
99
+ This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100
+
101
+ [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
102
+ [golang-coc]: https://golang.org/conduct
103
+ [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104
+ [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105
+ [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
repositories/BLIP/LICENSE.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022, Salesforce.com, Inc.
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5
+
6
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+
8
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9
+
10
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11
+
12
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
repositories/BLIP/README.md ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
2
+
3
+ <img src="BLIP.gif" width="700">
4
+
5
+ This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
6
+ To install the dependencies, run <pre/>pip install -r requirements.txt</pre>
7
+
8
+ Catalog:
9
+ - [x] Inference demo
10
+ - [x] Pre-trained and finetuned checkpoints
11
+ - [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
12
+ - [x] Pre-training code
13
+ - [x] Zero-shot video-text retrieval
14
+ - [x] Download of bootstrapped pre-training datasets
15
+
16
+
17
+ ### Inference demo:
18
+ Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
19
+ The demo includes code for:
20
+ 1. Image captioning
21
+ 2. Open-ended visual question answering
22
+ 3. Multimodal / unimodal feature extraction
23
+ 4. Image-text matching
24
+
25
+ Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
26
+
27
+ Replicate web demo and Docker image is also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
28
+
29
+ ### Pre-trained checkpoints:
30
+ Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
31
+ --- | :---: | :---: | :---:
32
+ 14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
33
+ 129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>
34
+
35
+ ### Finetuned checkpoints:
36
+ Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
37
+ --- | :---: | :---: | :---:
38
+ Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
39
+ Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
40
+ Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
41
+ VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
42
+ NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -
43
+
44
+
45
+ ### Image-Text Retrieval:
46
+ 1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
47
+ 2. To evaluate the finetuned BLIP model on COCO, run:
48
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
49
+ --config ./configs/retrieval_coco.yaml \
50
+ --output_dir output/retrieval_coco \
51
+ --evaluate</pre>
52
+ 3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
53
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
54
+ --config ./configs/retrieval_coco.yaml \
55
+ --output_dir output/retrieval_coco </pre>
56
+
57
+ ### Image-Text Captioning:
58
+ 1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
59
+ 2. To evaluate the finetuned BLIP model on COCO, run:
60
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
61
+ 3. To evaluate the finetuned BLIP model on NoCaps, generate results with: (evaluation needs to be performed on official server)
62
+ <pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
63
+ 4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
64
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
65
+
66
+ ### VQA:
67
+ 1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
68
+ 2. To evaluate the finetuned BLIP model, generate results with: (evaluation needs to be performed on official server)
69
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
70
+ 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
71
+ <pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
72
+
73
+ ### NLVR2:
74
+ 1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
75
+ 2. To evaluate the finetuned BLIP model, run
76
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
77
+ 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
78
+ <pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
79
+
80
+ ### Finetune with ViT-L:
81
+ In order to finetune a model with ViT-L, simply change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpoint</a> can also be activated in the config file to reduce GPU memory usage.
82
+
83
+ ### Pre-train:
84
+ 1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
85
+ 2. In configs/pretrain.yaml, set 'train_file' as the paths for the json files .
86
+ 3. Pre-train the model using 8 A100 GPUs:
87
+ <pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain </pre>
88
+
89
+ ### Zero-shot video-text retrieval:
90
+ 1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
91
+ 2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
92
+ 3. To perform zero-shot evaluation, run
93
+ <pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>
94
+
95
+ ### Pre-training datasets download:
96
+ We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.
97
+
98
+ Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
99
+ --- | :---: | :---: | :---:
100
+ CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
101
+ LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
102
+
103
+ ### Citation
104
+ If you find this code to be useful for your research, please consider citing.
105
+ <pre>
106
+ @inproceedings{li2022blip,
107
+ title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
108
+ author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
109
+ year={2022},
110
+ booktitle={ICML},
111
+ }</pre>
112
+
113
+ ### Acknowledgement
114
+ The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.
repositories/BLIP/SECURITY.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ## Security
2
+
3
+ Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4
+ as soon as it is discovered. This library limits its runtime dependencies in
5
+ order to reduce the total cost of ownership as much as possible, but all consumers
6
+ should remain vigilant and have their security stakeholders review all third-party
7
+ products (3PP) like this one and their dependencies.
repositories/BLIP/cog.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build:
2
+ gpu: true
3
+ cuda: "11.1"
4
+ python_version: "3.8"
5
+ system_packages:
6
+ - "libgl1-mesa-glx"
7
+ - "libglib2.0-0"
8
+ python_packages:
9
+ - "ipython==7.30.1"
10
+ - "torchvision==0.11.1"
11
+ - "torch==1.10.0"
12
+ - "timm==0.4.12"
13
+ - "transformers==4.15.0"
14
+ - "fairscale==0.4.4"
15
+ - "pycocoevalcap==1.2"
16
+
17
+ predict: "predict.py:Predictor"
repositories/BLIP/configs/bert_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30522,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }