diff --git a/.gitattributes b/.gitattributes index 62b9d02b808ffa693cff84fd595b58ac743a5ab4..cef6514019122c8978d0e50384624abec00458a8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text extensions/Stable-Diffusion-Webui-Civitai-Helper/img/all_in_one.png filter=lfs diff=lfs merge=lfs -text extensions/addtional/models/lora/README.md filter=lfs diff=lfs merge=lfs -text +repositories/BLIP/BLIP.gif filter=lfs diff=lfs merge=lfs -text diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..953051409dd87a119ed54f60238336bd50396184 --- /dev/null +++ b/modules/textual_inversion/logging.py @@ -0,0 +1,64 @@ +import datetime +import json +import os + +saved_params_shared = { + "batch_size", + "clip_grad_mode", + "clip_grad_value", + "create_image_every", + "data_root", + "gradient_step", + "initial_step", + "latent_sampling_method", + "learn_rate", + "log_directory", + "model_hash", + "model_name", + "num_of_dataset_images", + "steps", + "template_file", + "training_height", + "training_width", +} +saved_params_ti = { + "embedding_name", + "num_vectors_per_token", + "save_embedding_every", + "save_image_with_stored_embedding", +} +saved_params_hypernet = { + "activation_func", + "add_layer_norm", + "hypernetwork_name", + "layer_structure", + "save_hypernetwork_every", + "use_dropout", + "weight_init", +} +saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet +saved_params_previews = { + "preview_cfg_scale", + "preview_height", + "preview_negative_prompt", + "preview_prompt", + "preview_sampler_index", + "preview_seed", + "preview_steps", + "preview_width", +} + + +def save_settings_to_file(log_directory, all_params): + now = datetime.datetime.now() + params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} + + keys = saved_params_all + if all_params.get('preview_from_txt2img'): + keys = keys | saved_params_previews + + params.update({k: v for k, v in all_params.items() if k in keys}) + + filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(os.path.join(log_directory, filename), "w") as file: + json.dump(params, file, indent=4) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..914adebb8c869d65c89541e6d8196bbcb750ccee --- /dev/null +++ b/modules/textual_inversion/preprocess.py @@ -0,0 +1,232 @@ +import os +from PIL import Image, ImageOps +import math +import tqdm + +from modules import paths, shared, images, deepbooru +from modules.textual_inversion import autocrop + + +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): + try: + if process_caption: + shared.interrogator.load() + + if process_caption_deepbooru: + deepbooru.model.start() + + 
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) + + finally: + + if process_caption: + shared.interrogator.send_blip_to_ram() + + if process_caption_deepbooru: + deepbooru.model.stop() + + +def listfiles(dirname): + return os.listdir(dirname) + + +class PreprocessParams: + src = None + dstdir = None + subindex = 0 + flip = False + process_caption = False + process_caption_deepbooru = False + preprocess_txt_action = None + + +def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None): + caption = "" + + if params.process_caption: + caption += shared.interrogator.generate_caption(image) + + if params.process_caption_deepbooru: + if caption: + caption += ", " + caption += deepbooru.model.tag_multi(image) + + filename_part = params.src + filename_part = os.path.splitext(filename_part)[0] + filename_part = os.path.basename(filename_part) + + basename = f"{index:05}-{params.subindex}-{filename_part}" + image.save(os.path.join(params.dstdir, f"{basename}.png")) + + if params.preprocess_txt_action == 'prepend' and existing_caption: + caption = f"{existing_caption} {caption}" + elif params.preprocess_txt_action == 'append' and existing_caption: + caption = f"{caption} {existing_caption}" + elif params.preprocess_txt_action == 'copy' and existing_caption: + caption = existing_caption + + caption = caption.strip() + + if caption: + with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: + file.write(caption) + + params.subindex += 1 + + +def save_pic(image, index, params, existing_caption=None): + save_pic_with_caption(image, index, params, existing_caption=existing_caption) + + if params.flip: + save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption) + + +def split_pic(image, inverse_xy, width, height, overlap_ratio): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + +# not using torchvision.transforms.CenterCrop because it doesn't allow float regions +def center_crop(image: Image, w: int, h: int): + iw, ih = image.size + if ih / h < iw / w: + sw = w * ih / h + box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih + else: + sh = h * iw / w + box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 + return image.resize((w, h), Image.Resampling.LANCZOS, box) + + +def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): + iw, ih = image.size + err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) + wh = max(((w, h) 
for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1], + default=None + ) + return wh and center_crop(image, *wh) + + +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): + width = process_width + height = process_height + src = os.path.abspath(process_src) + dst = os.path.abspath(process_dst) + split_threshold = max(0.0, min(1.0, split_threshold)) + overlap_ratio = max(0.0, min(0.9, overlap_ratio)) + + assert src != dst, 'same directory specified as source and destination' + + os.makedirs(dst, exist_ok=True) + + files = listfiles(src) + + shared.state.job = "preprocess" + shared.state.textinfo = "Preprocessing..." + shared.state.job_count = len(files) + + params = PreprocessParams() + params.dstdir = dst + params.flip = process_flip + params.process_caption = process_caption + params.process_caption_deepbooru = process_caption_deepbooru + params.preprocess_txt_action = preprocess_txt_action + + pbar = tqdm.tqdm(files) + for index, imagefile in enumerate(pbar): + params.subindex = 0 + filename = os.path.join(src, imagefile) + try: + img = Image.open(filename) + img = ImageOps.exif_transpose(img) + img = img.convert("RGB") + except Exception: + continue + + description = f"Preprocessing [Image {index}/{len(files)}]" + pbar.set_description(description) + shared.state.textinfo = description + + params.src = filename + + existing_caption = None + existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt" + if os.path.exists(existing_caption_filename): + with open(existing_caption_filename, 'r', encoding="utf8") as file: + existing_caption = file.read() + + if shared.state.interrupted: + break + + if img.height > img.width: + ratio = (img.width * height) / (img.height * width) + inverse_xy = False + else: + ratio = (img.height * width) / (img.width * height) + inverse_xy = True + + process_default_resize = True + + if process_split and ratio < 1.0 and ratio <= split_threshold: + for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio): + save_pic(splitted, index, params, existing_caption=existing_caption) + process_default_resize = False + + if process_focal_crop and img.height != img.width: + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) + except Exception as e: + print("Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.", e) + + autocrop_settings = autocrop.Settings( + crop_width = width, + crop_height = height, + face_points_weight = process_focal_crop_face_weight, + entropy_points_weight = process_focal_crop_entropy_weight, + corner_points_weight = process_focal_crop_edges_weight, + annotate_image = process_focal_crop_debug, + dnn_model_path = dnn_model_path, + ) + for focal in autocrop.crop_image(img, autocrop_settings): + save_pic(focal, index, params, existing_caption=existing_caption) + process_default_resize = False + + if process_multicrop: + cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) + if cropped is not None: + save_pic(cropped, index, params, existing_caption=existing_caption) + else: + print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)") + process_default_resize = False + + if process_keep_original_size: + save_pic(img, index, params, existing_caption=existing_caption) + process_default_resize = False + + if process_default_resize: + img = images.resize_image(1, img, width, height) + save_pic(img, index, params, existing_caption=existing_caption) + + shared.state.nextjob() diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png new file mode 100644 index 0000000000000000000000000000000000000000..07e2d9afaeaff3751b68a7c0f49d8b3466474282 Binary files /dev/null and b/modules/textual_inversion/test_embedding.png differ diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py new file mode 100644 index 0000000000000000000000000000000000000000..cf24f5497231b596792f02ab3674113ca4c80857 --- /dev/null +++ b/modules/textual_inversion/textual_inversion.py @@ -0,0 +1,683 @@ +import os +from collections import namedtuple +from contextlib import closing + +import torch +import tqdm +import html +import datetime +import csv +import safetensors.torch + +import numpy as np +from PIL import Image, PngImagePlugin +from torch.utils.tensorboard import SummaryWriter + +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes +import modules.textual_inversion.dataset +from modules.textual_inversion.learn_schedule import LearnRateScheduler + +from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay +from modules.textual_inversion.logging import save_settings_to_file + + +TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) +textual_inversion_templates = {} + + +def list_textual_inversion_templates(): + textual_inversion_templates.clear() + + for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for fn in fns: + path = os.path.join(root, fn) + + textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) + + return textual_inversion_templates + + +class Embedding: + def __init__(self, vec, name, step=None): + self.vec = vec + self.name = name + self.step = step + self.shape = None + self.vectors = 0 + self.cached_checksum = None + self.sd_checkpoint = None + self.sd_checkpoint_name = None + self.optimizer_state_dict = None + self.filename = None + self.hash = None + self.shorthash = None + + def save(self, filename): + embedding_data 
= { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + "sd_checkpoint": self.sd_checkpoint, + "sd_checkpoint_name": self.sd_checkpoint_name, + } + + torch.save(embedding_data, filename) + + if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None: + optimizer_saved_dict = { + 'hash': self.checksum(), + 'optimizer_state_dict': self.optimizer_state_dict, + } + torch.save(optimizer_saved_dict, f"{filename}.optim") + + def checksum(self): + if self.cached_checksum is not None: + return self.cached_checksum + + def const_hash(a): + r = 0 + for v in a: + r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF + return r + + self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' + return self.cached_checksum + + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + + +class DirWithTextualInversionEmbeddings: + def __init__(self, path): + self.path = path + self.mtime = None + + def has_changed(self): + if not os.path.isdir(self.path): + return False + + mt = os.path.getmtime(self.path) + if self.mtime is None or mt > self.mtime: + return True + + def update(self): + if not os.path.isdir(self.path): + return + + self.mtime = os.path.getmtime(self.path) + + +class EmbeddingDatabase: + def __init__(self): + self.ids_lookup = {} + self.word_embeddings = {} + self.skipped_embeddings = {} + self.expected_shape = -1 + self.embedding_dirs = {} + self.previously_displayed_embeddings = () + + def add_embedding_dir(self, path): + self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) + + def clear_embedding_dirs(self): + self.embedding_dirs.clear() + + def register_embedding(self, embedding, model): + return self.register_embedding_by_name(embedding, model, embedding.name) + + def register_embedding_by_name(self, embedding, model, name): + ids = model.cond_stage_model.tokenize([name])[0] + first_id = ids[0] + if first_id not in self.ids_lookup: + self.ids_lookup[first_id] = [] + if name in self.word_embeddings: + # remove old one from the lookup list + lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name] + else: + lookup = self.ids_lookup[first_id] + if embedding is not None: + lookup += [(ids, embedding)] + self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True) + if embedding is None: + # unregister embedding with specified name + if name in self.word_embeddings: + del self.word_embeddings[name] + if len(self.ids_lookup[first_id])==0: + del self.ids_lookup[first_id] + return None + self.word_embeddings[name] = embedding + return embedding + + def get_expected_shape(self): + vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) + return vec.shape[1] + + def load_from_file(self, path, filename): + name, ext = os.path.splitext(filename) + ext = ext.upper() + + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + + embed_image = Image.open(path) + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) + else: + data = extract_image_data_embed(embed_image) + if data: + name = data.get('name', name) + else: + # if data is None, means this is not an embeding, just a preview image + return + elif ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + data = 
safetensors.torch.load_file(path, device="cpu") + else: + return + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + embedding.filename = path + embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') + + if self.expected_shape == -1 or self.expected_shape == embedding.shape: + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding + + def load_from_dir(self, embdir): + if not os.path.isdir(embdir.path): + return + + for root, _, fns in os.walk(embdir.path, followlinks=True): + for fn in fns: + try: + fullfn = os.path.join(root, fn) + + if os.stat(fullfn).st_size == 0: + continue + + self.load_from_file(fullfn, fn) + except Exception: + errors.report(f"Error loading embedding {fn}", exc_info=True) + continue + + def load_textual_inversion_embeddings(self, force_reload=False): + if not force_reload: + need_reload = False + for embdir in self.embedding_dirs.values(): + if embdir.has_changed(): + need_reload = True + break + + if not need_reload: + return + + self.ids_lookup.clear() + self.word_embeddings.clear() + self.skipped_embeddings.clear() + self.expected_shape = self.get_expected_shape() + + for embdir in self.embedding_dirs.values(): + self.load_from_dir(embdir) + embdir.update() + + # re-sort word_embeddings because load_from_dir may not load in alphabetic order. + # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it. 
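+        # note: dict.clear() + dict.update() mutate the dict in place, so any external
+        # reference to self.word_embeddings keeps pointing at the re-sorted contents.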
+ sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())} + self.word_embeddings.clear() + self.word_embeddings.update(sorted_word_embeddings) + + displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) + if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings: + self.previously_displayed_embeddings = displayed_embeddings + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") + if self.skipped_embeddings: + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") + + def find_embedding_at_position(self, tokens, offset): + token = tokens[offset] + possible_matches = self.ids_lookup.get(token, None) + + if possible_matches is None: + return None, None + + for ids, embedding in possible_matches: + if tokens[offset:offset + len(ids)] == ids: + return embedding, len(ids) + + return None, None + + +def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): + cond_model = shared.sd_model.cond_stage_model + + with devices.autocast(): + cond_model([""]) # will send cond model to GPU if lowvram/medvram is active + + #cond_model expects at least some text, so we provide '*' as backup. + embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token) + vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) + + #Only copy if we provided an init_text, otherwise keep vectors as zeros + if init_text: + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + + # Remove illegal characters from name. 
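+    # keep only characters that are safe in a filename, since the name becomes the .pt file name below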
+ name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" + + embedding = Embedding(vec, name) + embedding.step = 0 + embedding.save(fn) + + return fn + + +def write_loss(log_directory, filename, step, epoch_len, values): + if shared.opts.training_write_csv_every == 0: + return + + if step % shared.opts.training_write_csv_every != 0: + return + write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True + + with open(os.path.join(log_directory, filename), "a+", newline='') as fout: + csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) + + if write_csv_header: + csv_writer.writeheader() + + epoch = (step - 1) // epoch_len + epoch_step = (step - 1) % epoch_len + + csv_writer.writerow({ + "step": step, + "epoch": epoch, + "epoch_step": epoch_step, + **values, + }) + +def tensorboard_setup(log_directory): + os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) + return SummaryWriter( + log_dir=os.path.join(log_directory, "tensorboard"), + flush_secs=shared.opts.training_tensorboard_flush_every) + +def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num): + tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step) + tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step) + tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step) + tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) + +def tensorboard_add_scaler(tensorboard_writer, tag, value, step): + tensorboard_writer.add_scalar(tag=tag, + scalar_value=value, global_step=step) + +def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): + # Convert a pil image to a torch tensor + img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], + len(pil_image.getbands())) + img_tensor = img_tensor.permute((2, 0, 1)) + + tensorboard_writer.add_image(tag, img_tensor, global_step=step) + +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): + assert model_name, f"{name} not selected" + assert learn_rate, "Learning rate is empty or 0" + assert isinstance(batch_size, int), "Batch size must be integer" + assert batch_size > 0, "Batch size must be positive" + assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" + assert gradient_step > 0, "Gradient accumulation step must be positive" + assert data_root, "Dataset directory is empty" + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + assert template_filename, "Prompt template file not selected" + assert template_file, f"Prompt template file {template_filename} not found" + assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" + assert steps, "Max steps is empty or 0" + assert isinstance(steps, int), "Max steps must be integer" + assert steps > 0, "Max steps must be positive" + assert isinstance(save_model_every, int), "Save {name} must be integer" + assert save_model_every >= 0, "Save {name} must be positive or 0" + 
assert isinstance(create_image_every, int), "Create image must be integer" + assert create_image_every >= 0, "Create image must be positive or 0" + if save_model_every or create_image_every: + assert log_directory, "Log directory is empty" + + +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + save_embedding_every = save_embedding_every or 0 + create_image_every = create_image_every or 0 + template_file = textual_inversion_templates.get(template_filename, None) + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = template_file.path + + shared.state.job = "train-embedding" + shared.state.textinfo = "Initializing textual inversion training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) + unload = shared.opts.unload_models_when_training + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + if create_image_every > 0 and save_image_with_stored_embedding: + images_embeds_dir = os.path.join(log_directory, "image_embeddings") + os.makedirs(images_embeds_dir, exist_ok=True) + else: + images_embeds_dir = None + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + checkpoint = sd_models.select_checkpoint() + + initial_step = embedding.step or 0 + if initial_step >= steps: + shared.state.textinfo = "Model has already been trained beyond specified max steps" + return embedding, filename + + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ + torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ + None + if clip_grad: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) + # dataset loading may take a while, so input validations and early returns should be done before this + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
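+    # remember the current value so the finally block below can restore it once training ends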
+ old_parallel_processing_allowed = shared.parallel_processing_allowed + + if shared.opts.training_enable_tensorboard: + tensorboard_writer = tensorboard_setup(log_directory) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight) + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) + + latent_sampling_method = ds.latent_sampling_method + + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) + + if unload: + shared.parallel_processing_allowed = False + shared.sd_model.first_stage_model.to(devices.cpu) + + embedding.vec.requires_grad = True + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) + if shared.opts.save_optimizer_state: + optimizer_state_dict = None + if os.path.exists(f"{filename}.optim"): + optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu') + if embedding.checksum() == optimizer_saved_dict.get('hash', None): + optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + + if optimizer_state_dict is not None: + optimizer.load_state_dict(optimizer_state_dict) + print("Loaded existing optimizer from checkpoint") + else: + print("No saved optimizer exists in checkpoint") + + scaler = torch.cuda.amp.GradScaler() + + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + + last_saved_file = "" + last_saved_image = "" + forced_filename = "" + embedding_yet_to_be_embedded = False + + is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'} + img_c = None + + pbar = tqdm.tqdm(total=steps - initial_step) + try: + sd_hijack_checkpoint.add() + + for _ in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + if shared.state.interrupted: + break + + if clip_grad: + clip_grad_sched.step(embedding.step) + + with devices.autocast(): + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + if use_weight: + w = batch.weight.to(devices.device, non_blocking=pin_memory) + c = shared.sd_model.cond_stage_model(batch.cond_text) + + if is_training_inpainting_model: + if img_c is None: + img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height) + + cond = {"c_concat": [img_c], "c_crossattn": [c]} + else: + 
cond = c + + if use_weight: + loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step + del w + else: + loss = shared.sd_model.forward(x, cond)[0] / gradient_step + del x + + _loss_step += loss.item() + scaler.scale(loss).backward() + + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + + if clip_grad: + clip_grad(embedding.vec, clip_grad_sched.learn_rate) + + scaler.step(optimizer) + scaler.update() + embedding.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // steps_per_epoch + epoch_step = embedding.step % steps_per_epoch + + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" + pbar.set_description(description) + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height + + preview_text = p.prompt + + with closing(p): + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None + + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.assign_current_image(image) + + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) + + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) + + title = f"<{data.get('name', '???')}>" + + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception: + 
vectorSize = '?' + + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = f'[{checkpoint.shorthash}]' + footer_right = f'{vectorSize}v {steps_done}s' + + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False + + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +
+<p>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+""" + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) + except Exception: + errors.report("Error training embedding", exc_info=True) + finally: + pbar.leave = False + pbar.close() + shared.sd_model.first_stage_model.to(devices.device) + shared.parallel_processing_allowed = old_parallel_processing_allowed + sd_hijack_checkpoint.remove() + + return embedding, filename + + +def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): + old_embedding_name = embedding.name + old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None + old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None + old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None + try: + embedding.sd_checkpoint = checkpoint.shorthash + embedding.sd_checkpoint_name = checkpoint.model_name + if remove_cached_checksum: + embedding.cached_checksum = None + embedding.name = embedding_name + embedding.optimizer_state_dict = optimizer.state_dict() + embedding.save(filename) + except: + embedding.sd_checkpoint = old_sd_checkpoint + embedding.sd_checkpoint_name = old_sd_checkpoint_name + embedding.name = old_embedding_name + embedding.cached_checksum = old_cached_checksum + raise diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..5b75f799e745fa693cda06763af80069324a964f --- /dev/null +++ b/modules/textual_inversion/ui.py @@ -0,0 +1,45 @@ +import html + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared + + +def create_embedding(name, initialization_text, nvpt, overwrite_old): + filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" + + +def preprocess(*args): + modules.textual_inversion.preprocess.preprocess(*args) + + return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" + + +def train_embedding(*args): + + assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' + + apply_optimizations = shared.opts.training_xattention_optimizations + try: + if not apply_optimizations: + sd_hijack.undo_optimizations() + + embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. 
+Embedding saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + if not apply_optimizations: + sd_hijack.apply_optimizations() + diff --git a/modules/timer.py b/modules/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..22d1272dddfb5710525c04126a8b17e449c7a8d9 --- /dev/null +++ b/modules/timer.py @@ -0,0 +1,91 @@ +import time +import argparse + + +class TimerSubcategory: + def __init__(self, timer, category): + self.timer = timer + self.category = category + self.start = None + self.original_base_category = timer.base_category + + def __enter__(self): + self.start = time.time() + self.timer.base_category = self.original_base_category + self.category + "/" + self.timer.subcategory_level += 1 + + if self.timer.print_log: + print(f"{' ' * self.timer.subcategory_level}{self.category}:") + + def __exit__(self, exc_type, exc_val, exc_tb): + elapsed_for_subcategroy = time.time() - self.start + self.timer.base_category = self.original_base_category + self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy) + self.timer.subcategory_level -= 1 + self.timer.record(self.category, disable_log=True) + + +class Timer: + def __init__(self, print_log=False): + self.start = time.time() + self.records = {} + self.total = 0 + self.base_category = '' + self.print_log = print_log + self.subcategory_level = 0 + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def add_time_to_record(self, category, amount): + if category not in self.records: + self.records[category] = 0 + + self.records[category] += amount + + def record(self, category, extra_time=0, disable_log=False): + e = self.elapsed() + + self.add_time_to_record(self.base_category + category, e + extra_time) + + self.total += e + extra_time + + if self.print_log and not disable_log: + print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s") + + def subcategory(self, name): + self.elapsed() + + subcat = TimerSubcategory(self, name) + return subcat + + def summary(self): + res = f"{self.total:.1f}s" + + additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category] + if not additions: + return res + + res += " (" + res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) + res += ")" + + return res + + def dump(self): + return {'total': self.total, 'records': self.records} + + def reset(self): + self.__init__() + + +parser = argparse.ArgumentParser(add_help=False) +parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup") +args = parser.parse_known_args()[0] + +startup_timer = Timer(print_log=args.log_startup) + +startup_record = None diff --git a/modules/txt2img.py b/modules/txt2img.py new file mode 100644 index 0000000000000000000000000000000000000000..fa208f64dc4d94bbd154677935dacf696da0b456 --- /dev/null +++ b/modules/txt2img.py @@ -0,0 +1,73 @@ +from contextlib import closing + +import modules.scripts +from modules import sd_samplers, processing +from modules.generation_parameters_copypaste import create_override_settings_dict +from modules.shared import opts, cmd_opts +import modules.shared as shared +from modules.ui import plaintext_to_html +import gradio as gr + + +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, 
n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): + override_settings = create_override_settings_dict(override_settings_texts) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, + outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, + prompt=prompt, + styles=prompt_styles, + negative_prompt=negative_prompt, + seed=seed, + subseed=subseed, + subseed_strength=subseed_strength, + seed_resize_from_h=seed_resize_from_h, + seed_resize_from_w=seed_resize_from_w, + seed_enable_extras=seed_enable_extras, + sampler_name=sd_samplers.samplers[sampler_index].name, + batch_size=batch_size, + n_iter=n_iter, + steps=steps, + cfg_scale=cfg_scale, + width=width, + height=height, + restore_faces=restore_faces, + tiling=tiling, + enable_hr=enable_hr, + denoising_strength=denoising_strength if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, + hr_second_pass_steps=hr_second_pass_steps, + hr_resize_x=hr_resize_x, + hr_resize_y=hr_resize_y, + hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, + hr_prompt=hr_prompt, + hr_negative_prompt=hr_negative_prompt, + override_settings=override_settings, + ) + + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args + + p.user = request.username + + if cmd_opts.enable_console_prompts: + print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + + with closing(p): + processed = modules.scripts.scripts_txt2img.run(p, *args) + + if processed is None: + processed = processing.process_images(p) + + shared.total_tqdm.clear() + + generation_info_js = processed.js() + if opts.samples_log_stdout: + print(generation_info_js) + + if opts.do_not_show_images: + processed.images = [] + + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") diff --git a/modules/ui.py b/modules/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa24118860eb37732e25c8d295fd2c4e185a89d --- /dev/null +++ b/modules/ui.py @@ -0,0 +1,2084 @@ +import datetime +import json +import mimetypes +import os +import sys +from functools import reduce +import warnings + +import gradio as gr +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin # noqa: F401 +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo +#from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave + + +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path, data_path + +from modules.shared import opts, cmd_opts + +import modules.codeformer_model +import 
modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text +import modules.extras + +create_setting_component = ui_settings.create_setting_component + +warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_options + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️ +extra_networks_symbol = '\U0001F3B4' # 🎴 + +#switch_values_symbol = '\U000021C5' # ⇅ +switch_values_symbol = '\u2B80' # ⮀ +restore_progress_symbol = '\U0001F300' # 🌀 +detect_image_size_symbol = '\U0001F4D0' # 📐 +up_down_symbol = '\u2195\ufe0f' # ↕️ + +interogate_bubble_symbol = '\U0001F5E8' # 🗨 +interogate_2bubble_symbol = '\U0001F5EA' # 🗪 +generate_forever_symbol = '\u267E' # ♾ + +plaintext_to_html = ui_common.plaintext_to_html + + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"from 
{p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def resize_from_to_html(width, height, scale_by): + target_width = int(width * scale_by) + target_height = int(height * scale_by) + + if not target_width or not target_height: + return "no image selected" + + return f"from {width}x{height} to {target_width}x{target_height}" + + +def apply_styles(prompt, prompt_neg, styles): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] + + +def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles): + if mode in {0, 1, 3, 4}: + return [interrogation_function(ii_singles[mode]), None] + elif mode == 2: + return [interrogation_function(ii_singles[mode]["image"]), None] + elif mode == 5: + assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + images = shared.listfiles(ii_input_dir) + print(f"Will process {len(images)} images.") + if ii_output_dir != "": + os.makedirs(ii_output_dir, exist_ok=True) + else: + ii_output_dir = ii_input_dir + + for image in images: + img = Image.open(image) + filename = os.path.basename(image) + left, _ = os.path.splitext(filename) + print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8')) + + return [gr.update(), None] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + return gr.update() if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr.update() if prompt is None else prompt + + +def create_seed_inputs(target_interface): + + with gr.Row(elem_id = target_interface+"_group_seed"): + with gr.Box(): + with gr.Row(elem_id=target_interface + '_seed_row-collapse-all'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + random_seed = ToolButton(value=random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = ToolButton(value=reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Box(elem_id='subseed_show_box-collapse-all'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + + + # use sub-group + # at any place to indicate a different style already defined by the css rules + with gr.Column(elem_id=target_interface + '_subseed_row_sub-group', visible=False) as seed_extra_group: + + seed_extras.append(seed_extra_group) + + with gr.Row(visible=False) as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + with gr.Box(): + with gr.Row(elem_id= target_interface + '_subseed_row-collapse-all'): + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + #subseed.style(container=False) + random_subseed = ToolButton(value=random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = ToolButton(value=reuse_symbol, elem_id=target_interface + '_reuse_subseed') + with gr.Box(elem_id= target_interface + '_subseed_strength_row-collapse-all'): + #with gr.Row(elem_id= target_interface + '_subseed_strength_row'): + subseed_strength = gr.Slider(label='Variation strength', value=0.0, 
minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') + + with gr.Row(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') + + random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[]) + random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + + +def connect_clear_prompt(button): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" + button.click( + _js="clear_prompt", + fn=None, + inputs=[], + outputs=[], + ) + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError: + if gen_info_string: + errors.report(f"Error parsing JSON generation info: {gen_info_string}") + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + return f"{token_count}/{max_length}" + +def create_generate(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Column(): + with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): + interrupt = gr.Button('Interrupt', 
elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") + skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + gr.Button(value=generate_forever_symbol, elem_id=f"{id_part}_generate_forever", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + return submit + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id=f"{id_part}_toprow-collapse", variant="compact"): + with gr.Column(): + with gr.Row(): + with gr.Column(): + with gr.Column(elem_id=f"{id_part}_styles_row"): + prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) + create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_style_index") + + with gr.Row(): + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=True, lines=3, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Row(): + negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=True, lines=3, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(elem_id=f"{id_part}_actions_column"): + with gr.Row(elem_id=f"{id_part}_tools"): + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") + + button_interrogate = None + button_deepbooru = None + if is_img2img: + button_interrogate = ToolButton(value=interogate_bubble_symbol, elem_id="interrogate") + button_deepbooru = ToolButton(value=interogate_2bubble_symbol, elem_id="deepbooru") + + restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) + + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + + return prompt, prompt_styles, negative_prompt, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button + + + +def setup_progressbar(*args, **kwargs): + pass + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + 
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return getattr(opts, key) + + +def create_output_panel(tabname, outdir): + return ui_common.create_output_panel(tabname, outdir) + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)} + + for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)): + yield category + +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + +def create_override_settings_dropdown(tabname, row): + dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) + + dropdown.change( + fn=lambda x: gr.Dropdown.update(visible=bool(x)), + inputs=[dropdown], + outputs=[dropdown], + ) + + return dropdown + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + + #submit = create_generate(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + + with gr.Row(): + txt2img_gallery, generation_info, html_info, html_log = 
create_output_panel("txt2img", opts.outdir_txt2img_samples) + gr.Row(elem_id="txt2img_splitter") + + with gr.Column(variant='panel', elem_id="txt2img_settings"): + + submit = create_generate(is_img2img=False) + + with gr.Column(elem_id="txt2img_settings_scroll"): + with gr.Accordion("Prompt", open=True): + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False) + + with gr.Row(elem_id="txt2img_extra_networks_row", visible=True) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + #modules.scripts.scripts_txt2img.prepare_ui() + + #with gr.Accordion("Parameters", open=True): + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with gr.Row(elem_id="txt2img_dimensions_row", elem_classes="dimensions-tools"): + + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Row(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + #with FormRow(elem_id="txt2img_checkboxes", variant="compact"): + with gr.Row(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. 
fix', value=False, elem_id="txt2img_enable_hr") + + elif category == "hires_fix": + with gr.Column(visible=False, elem_id="txt2img_hires_fix_sub-group") as hr_options: + with gr.Row(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + with gr.Row(): + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + #with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): + with gr.Row(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + with gr.Row(elem_id="txt2img_hires_fix_row3", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") + + with gr.Row(elem_id="txt2img_hires_fix_row4", visible=opts.hires_fix_show_prompts) as hr_prompts_container: + with gr.Column(scale=80): + with gr.Row(): + hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) + with gr.Column(scale=80): + with gr.Row(): + hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) + + + elif category == "batch": + if not opts.dimensions_and_batch_together: + #with FormRow(elem_id="txt2img_column_batch"): + with gr.Row(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "override_settings": + with gr.Row(elem_id="txt2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('txt2img', row) + + elif category == "scripts": + #with FormRow(elem_id="txt2img_script_container"): + #with gr.Group(): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + else: + modules.scripts.scripts_txt2img.setup_ui_for_section(category) + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + + for component in hr_resolution_preview_inputs: + event = component.release if isinstance(component, gr.Slider) else component.change + + event( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + event( + None, + _js="onCalcResolutionHires", + 
inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + dummy_component, + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_styles, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + hr_sampler_index, + hr_prompt, + hr_negative_prompt, + override_settings, + + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + #txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False) + + restore_progress_button.click( + fn=progress.restore_progress, + _js="restoreProgressTxt2img", + inputs=[dummy_component], + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ], + show_progress=False, + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + (hr_sampler_index, "Hires sampler"), + (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()), + (hr_prompt, "Hires prompt"), + (hr_negative_prompt, "Hires negative prompt"), + (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, + )) + + txt2img_preview_params = [ + txt2img_prompt, + 
txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + + with gr.Row(): + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + gr.Row(elem_id="img2img_splitter") + + with gr.Column(variant='panel', elem_id="img2img_settings"): + + submit = create_generate(is_img2img=True) + + with gr.Column(elem_id="img2img_settings_scroll"): + + with gr.Row(): + with gr.Accordion("Prompt", open=True): + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id="img2img_extra_networks_row", visible=True) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + + copy_image_buttons = [] + copy_image_destinations = {} + + def add_copy_image_controls(tab_name, elem): + with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): + gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") + + for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): + if name == tab_name: + gr.Button(title, interactive=False) + copy_image_destinations[name] = elem + continue + + button = gr.Button(title) + copy_image_buttons.append((button, name, elem)) + + with gr.Accordion("Image Source", open=True): + with gr.Tabs(elem_id="mode_img2img"): + img2img_selected_tab = gr.State(0) + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA") + add_copy_image_controls('img2img', init_img) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA") + add_copy_image_controls('sketch', sketch) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") + add_copy_image_controls('inpaint', init_img_with_mask) + + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch 
= gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA") + inpaint_color_sketch_orig = gr.State(None) + add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") + + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML( + f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." + + f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"<br>Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}</p>
" + ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info", open=False): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") + + + img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] + #img2img_image_inputs = [init_img, sketch, init_img_with_mask, inpaint_color_sketch] + + for i, tab in enumerate(img2img_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) + + for category in ordered_ui_categories(): + if category == "inpaint": + + with gr.Column(elem_id="dim_controls", visible=True): + with gr.Row(): + resize_mode = gr.Dropdown(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + #with gr.Row(): + #with gr.Column(elem_id="img2img_column_size"): + # width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + # res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") + # height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + #with gr.Column(elem_id="img2img_column_size"): + selected_scale_tab = gr.State(value=0) + + with gr.Tabs(elem_id="scale_tabs"): + with gr.Tab(label="Resize to") as tab_scale_to: + with gr.Row(): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") + detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + with gr.Tab(label="Resize by") as tab_scale_by: + scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") + with gr.Row(elem_id="img2img_scale_resolution_row"): + scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") + gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") + button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") + + on_change_args = dict( + fn=resize_from_to_html, + _js="currentImg2imgSourceResolution", + inputs=[dummy_component, dummy_component, scale_by], + outputs=scale_by_html, + show_progress=False, + ) + + scale_by.release(**on_change_args) + 
button_update_resize_to.click(**on_change_args) + + for component in [init_img, sketch]: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + + tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) + tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) + + with FormGroup(elem_id="inpaint_controls_sub-group-collapse", visible=False) as inpaint_controls: + + with gr.Row(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with gr.Row(): + inpainting_mask_invert = gr.Dropdown(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + inpainting_fill = gr.Dropdown(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with gr.Row(): + inpaint_full_res = gr.Dropdown(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + + def copy_image(img): + if isinstance(img, dict) and 'image' in img: + return img['image'] + return img + + for button, name, elem in copy_image_buttons: + + button.click( + fn=copy_image, + inputs=[elem], + outputs=[copy_image_destinations[name]], + ) + button.click( + fn=lambda: None, + _js="switch_to_"+name.replace(" ", "_"), + inputs=[], + outputs=[], + ) + + for category in ordered_ui_categories(): + + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + + if opts.dimensions_and_batch_together: + with gr.Row(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + + elif category == "cfg": + with gr.Row(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + #with FormRow(elem_id="img2img_checkboxes", variant="compact"): + with gr.Row(): + restore_faces = gr.Checkbox(label='Restore faces', value=False, 
visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with gr.Row(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + + elif category == "override_settings": + with gr.Row(elem_id="img2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('img2img', row) + + + elif category == "scripts": + #with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ], + show_progress=False, + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + dummy_component, + dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_styles, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + inpaint_color_sketch_orig, + init_img_inpaint, + init_mask_inpaint, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + image_cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + selected_scale_tab, + height, + width, + scale_by, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + img2img_batch_inpaint_mask_dir, + override_settings, + img2img_batch_use_png_info, + img2img_batch_png_info_props, + img2img_batch_png_info_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + interrogate_args = dict( + _js="get_img2img_tab_index", + inputs=[ + dummy_component, + img2img_batch_input_dir, + img2img_batch_output_dir, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + init_img_inpaint, + ], + outputs=[img2img_prompt, dummy_component], + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False) + + detect_image_size_btn.click( + fn=lambda w, h, _: (w or gr.update(), h or gr.update()), + _js="currentImg2imgSourceResolution", + inputs=[dummy_component, dummy_component, dummy_component], + outputs=[width, height], + show_progress=False, + ) + + restore_progress_button.click( + fn=progress.restore_progress, + _js="restoreProgressImg2img", + inputs=[dummy_component], + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + img2img_interrogate.click( + fn=lambda *args: process_interrogate(interrogate, *args), + **interrogate_args, + ) + + img2img_deepbooru.click( + fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), + **interrogate_args, + ) + + 
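# Shared style wiring for both tabs: each (prompt, negative_prompt) pair below is zipped with its tab's save/apply style buttons and the matching update_*_tokens JS function. + 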
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_styles, img2img_prompt_styles], + ) + + for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, styles], + outputs=[prompt, negative_prompt, styles], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter]) + + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (image_cfg_scale, "Image CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + )) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + ui_postprocessing.create_ui() + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + gr.Row(elem_id="png_2img_prompt_image", visible=False) + with gr.Row(): + with gr.Column(elem_id="png_2img_results"): + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + + gr.Row(elem_id="png_2img_splitter") + with gr.Column(variant='panel', elem_id="png_2img_settings"): + with gr.Column(elem_id="png_2img_settings_scroll"): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + for tabname, button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=button, tabname=tabname, 
source_text_component=generation_info, source_image_component=image, + )) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + def update_interp_description(value): + interp_description_css = "
<p style='margin-bottom: 2.5em'>{}</p>
" + interp_descriptions = { + "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."), + "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"), + "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M") + } + return interp_descriptions[value] + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + gr.Row(elem_id="modelmerger_2img_prompt_image", visible=False) + with gr.Row(): + with gr.Column(elem_id="modelmerger_2img_results"): + with gr.Group(elem_id="modelmerger_results_panel"): + modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False) + + gr.Row(elem_id="modelmerger_2img_splitter") + with gr.Column(variant='panel', elem_id="modelmerger_2img_settings"): + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') + with gr.Column(elem_id="modelmerger_2img_settings_scroll"): + interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") + + with FormRow(elem_id="modelmerger_models"): + with gr.Box(): + with gr.Row(elem_id="modelmerger_primary_row-collapse-all"): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + with gr.Box(): + with gr.Row(elem_id="modelmerger_secondary_row-collapse-all"): + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + with gr.Box(): + with gr.Row(elem_id="modelmerger_tertiary_row-collapse-all"): + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description]) + + with FormRow(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata") + + with FormRow(): + with gr.Column(): + config_source = 
gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + with gr.Column(): + with gr.Box(): + with gr.Row(elem_id="modelmerger_bake_in_vae_row-collapse-all"): + bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") + create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "refresh_modelmerger_bake_in_vae") + + with FormRow(): + discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") + + with gr.Blocks(analytics_enabled=False) as train_interface: + #with gr.Row(elem_id="textual_inversion_wiki"): + # gr.HTML(value="
<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") + + gr.Row(elem_id="ti_2img_prompt_image", visible=False) + with gr.Row(): + with gr.Column(elem_id="ti_2img_results"): + with gr.Column(elem_id='ti_gallery_container'): + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + gr.Row(elem_id="ti_2img_splitter") + with gr.Tabs(elem_id="train_tabs_2img_settings"): + with gr.Tab(label="Create embedding"): + #create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + with gr.Column(elem_id="embedding_2img_settings_scroll"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + with gr.Row(): + gr.HTML(value="") + + with gr.Tab(label="Create hypernetwork"): + #create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + with gr.Column(elem_id="hypernetwork_2img_settings_scroll"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + with gr.Row(): + gr.HTML(value="") + + + with gr.Tab(label="Preprocess images", id="preprocess_images"): + # with gr.Column(): + # with gr.Row(): + # interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + # run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + with gr.Column(elem_id="preprocess_2img_settings_scroll"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size") + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Column(visible=False) as process_multicrop_col: + gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + with gr.Row(): + process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") + process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper 
bound", value=768, elem_id="train_process_multicrop_maxdim") + with gr.Row(): + process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea") + process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") + with gr.Row(): + process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") + process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + with gr.Row(): + gr.HTML(value="") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + process_multicrop.change( + fn=lambda show: gr_show(show), + inputs=[process_multicrop], + outputs=[process_multicrop_col], + ) + + def get_textual_inversion_template_names(): + return sorted(textual_inversion.textual_inversion_templates) + + + with gr.Tab(label="Train", id="train"): + + with gr.Column(elem_id="train_2img_settings_scroll"): + gr.HTML(value="
<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>
") + with FormRow(): + with gr.Box(): + with gr.Row(elem_id="train_embedding_row-collapse-all"): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + with gr.Box(): + with gr.Row(elem_id="train_hypernetwork_row-collapse-all"): + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + with gr.Box(): + with gr.Row(elem_id="train_gradient_clipping_row-collapse-all"): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(label="Value", value="0.1") + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + with gr.Box(): + with gr.Row(elem_id="train_template_file_row-collapse-all"): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refresh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, 
elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + dummy_component, + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_keep_original_size, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + process_multicrop, + process_multicrop_mindim, + process_multicrop_maxdim, + process_multicrop_minarea, + process_multicrop_maxarea, + process_multicrop_objective, + process_multicrop_threshold, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + dummy_component, + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + use_weight, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + 
inputs=[ + dummy_component, + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + use_weight, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, section, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {t} for key {key}') + + elem_id = f"setting_{key}" + + if info.refresh is not None: + if is_quicksettings: + with gr.Row(elem_id=f'row_{elem_id}'): + with gr.Box(): + with gr.Row(elem_id=f'{elem_id}_row-collapse-one'): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + + with gr.Row(): + gr.Checkbox(label='', elem_id=f'{section}_add2quick_{elem_id}', value=True, interactive=True) + else: + with gr.Row(elem_id=f'row_{elem_id}'): + with gr.Box(): + with gr.Row(elem_id=f'{elem_id}_row-collapse-one'): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + with gr.Row(): + gr.Checkbox(label='', elem_id=f'{section}_add2quick_{elem_id}', value=False, interactive=True) + + else: + with gr.Row(elem_id=f'row_{elem_id}'): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + gr.Checkbox(label='', elem_id=f'{section}_add2quick_{elem_id}', value=is_quicksettings, interactive=True) + + return res + + loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) + + components = [] + component_dict = {} + shared.settings_components = component_dict + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
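+ # run_settings_single (below) type-checks a single option value, applies it via opts.set, saves the config, and returns the canonical value plus refreshed settings JSON; presumably it is bound to individual quicksettings controls elsewhere.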
+ + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return get_value_for_setting(key), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = opts.quicksettings_list + #quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + current_row = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + gr.Group() + current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) + current_tab.__enter__() + current_row = gr.Column(variant='panel', elem_id="{}_settings_2img_settings".format(elem_id)) + current_row.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k, elem_id) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): + loadsave.create_ui() + + with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + with gr.Row(): + unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + + with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + + def unload_sd_weights(): + modules.sd_models.unload_model_weights() + + def reload_sd_weights(): + modules.sd_models.reload_model_weights() + + unload_sd_model.click( + fn=unload_sd_weights, + inputs=[], + outputs=[] + ) + + reload_sd_model.click( + fn=reload_sd_weights, + inputs=[], + outputs=[] + ) + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], 
+ outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + restart_gradio.click( + fn=shared.state.request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "train"), + ] + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + # shared.tab_names = [] + # for _interface, label, _ifid in interfaces: + # shared.tab_names.append(label) + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(data_path, "user.css")): + with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="header-top"): + gr.Row(elem_id="nav_menu") + gr.Row(elem_id="nav_menu_header_tabs") + + with gr.Row(elem_id="quicksettings"): + with gr.Row(elem_id="top_row_sd_model_checkpoint"): + sd_model_checkpoint = create_setting_component("sd_model_checkpoint", "sd", is_quicksettings=True) + component_dict['sd_model_checkpoint'] = sd_model_checkpoint + + with gr.Column(elem_id="quicksettings_overflow"): + with gr.Row(elem_id="quicksettings_actions"): + gr.Checkbox(label='', elem_id="quicksettings_draggable", interactive=True) + #ToolButton(elem_id="quicksettings_sort_asc", interactive=True) + #ToolButton(elem_id="quicksettings_sort_desc", interactive=True) + with gr.Column(elem_id="quicksettings_overflow_container") as qsettings_row: + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + if( str(k) != "sd_model_checkpoint"): + #with gr.Row(elem_id=f"quick_row_{k}") as qsettings_row: + component = create_setting_component(k, item.section[0], is_quicksettings=True) + component_dict[k] = component + gr.Row(elem_id="theme_menu") + gr.Row(elem_id="extra_networks_menu") + gr.Row(elem_id="quick_menu") + + parameters_copypaste.connect_paste_params_buttons() + # with gr.Tabs(elem_id="tabs") as tabs: + # for interface, label, ifid in interfaces: + # with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"): + # interface.render() + + with gr.Tabs(elem_id="tabs") as tabs: + tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)} + sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999)) + + for interface, label, ifid in sorted_interfaces: + with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"): + interface.render() + + loadsave.add_block(interface, ifid) + + loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs) + + loadsave.setup_ui() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + footer = 
shared.html("footer.html") + footer = footer.format(versions=versions_html()) + gr.HTML(footer) + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") + text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[component_dict['sd_model_checkpoint'], dummy_component], + outputs=[component_dict['sd_model_checkpoint'], text_settings], + ) + sd_model_checkpoint.change( + fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[component_dict['sd_model_checkpoint'], dummy_component], + outputs=[component_dict['sd_model_checkpoint'], text_settings], + ) + + + for _i, k, _item in quicksettings_list: + component = component_dict[k] + info = opts.data_labels[k] + + change_handler = component.release if hasattr(component, 'release') else component.change + change_handler( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + show_progress=info.refresh is not None, + ) + + update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") + text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[component_dict['sd_model_checkpoint'], dummy_component], + outputs=[component_dict['sd_model_checkpoint'], text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + + def get_settings_values(): + return [get_value_for_setting(key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + queue=False, + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] + return results + + modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result]) + modelmerger_merge.click( + fn=wrap_gradio_gpu_call(modelmerger, 
extra_outputs=lambda: [gr.update() for _ in range(4)]), + _js='modelmerger', + inputs=[ + dummy_component, + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + config_source, + bake_in_vae, + discard_weights, + save_metadata, + ], + outputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + modelmerger_result, + ] + ) + + loadsave.dump_defaults() + demo.ui_loadsave = loadsave + + # Required as a workaround for change() event not triggering when loading values from ui-config.json + interp_description.value = update_interp_description(interp_method.value) + + return demo + + +def webpath(fn): + if fn.startswith(script_path): + web_path = os.path.relpath(fn, script_path).replace('\\', '/') + else: + web_path = os.path.abspath(fn) + + return f'file={web_path}?{os.path.getmtime(fn)}' + + +def javascript_html(): + # Ensure localization is in `window` before scripts + head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n' + #head = "" + + script_js = os.path.join(script_path, "script.js") + head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n' + + for script in modules.scripts.list_scripts("javascript", ".js"): + head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n' + + for script in modules.scripts.list_scripts("javascript", ".mjs"): + head += f'<script type="module" src="{webpath(script.path)}"></script>\n' + + if cmd_opts.theme: + head += f'<script type="text/javascript">set_theme("{cmd_opts.theme}");</script>\n' + + return head + + +def css_html(): + head = "" + def stylesheet(fn): + return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">' + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + head += stylesheet(cssfile) + + if os.path.exists(os.path.join(data_path, "user.css")): + head += stylesheet(os.path.join(data_path, "user.css")) + + return head + + +def reload_javascript(): + js = javascript_html() + #css = css_html() + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8")) + #res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + tag = launch.git_tag() + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +os: {sys.platform} +&#x2000;•&#x2000; +python: {python_version} +&#x2000;•&#x2000; +torch: {getattr(torch, '__long_version__',torch.__version__)} +&#x2000;•&#x2000; +xformers: {xformers_version} +&#x2000;•&#x2000; +gradio: {gr.__version__} +&#x2000;•&#x2000; +commit: {tag} +&#x2000;•&#x2000; +checkpoint: <a id="sd_checkpoint_hash">N/A</a>
+""" + +def setup_ui_api(app): + from pydantic import BaseModel, Field + from typing import List + + class QuicksettingsHint(BaseModel): + name: str = Field(title="Name of the quicksettings field") + label: str = Field(title="Label of the quicksettings field") + + def quicksettings_hint(): + return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()] + + app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint]) + + app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) + + app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"]) + + def download_sysinfo(attachment=False): + from fastapi.responses import PlainTextResponse + + text = sysinfo.get() + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + + return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'}) + + app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"]) + app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"]) + + diff --git a/modules/ui_common.py b/modules/ui_common.py new file mode 100644 index 0000000000000000000000000000000000000000..81a302bf9040a1411b5aa7ac234b533592c73bc1 --- /dev/null +++ b/modules/ui_common.py @@ -0,0 +1,244 @@ +import json +import html +import os +import platform +import sys + +import gradio as gr +import subprocess as sp + +from modules import call_queue, shared +from modules.generation_parameters_copypaste import image_from_url_text +import modules.images +from modules.ui_components import ToolButton + + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 + + +def update_generation_info(generation_info, html_info, img_index): + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info, gr.update() + + +def plaintext_to_html(text, classname=None): + content = "
\n".join(html.escape(x) for x in text.split('\n')) + + return f"

{content}

" if classname else f"

{content}

" + + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = shared.opts.outdir_save + save_to_dirs = shared.opts.use_save_to_dirs_for_ui + extension: str = shared.opts.samples_format + start_index = 0 + only_one = False + + if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + only_one = True + images = [images[index]] + start_index = index + + os.makedirs(shared.opts.outdir_save, exist_ok=True) + + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + p.batch_index = image_index-1 + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0] + namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True) + zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]") + zip_filepath = os.path.join(path, f"{zip_filename}.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def create_output_panel(tabname, outdir): + from modules import shared + import modules.generation_parameters_copypaste as parameters_copypaste + + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. 
+Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): + open_folder_button = gr.Button(folder_symbol, elem_id=f'open_folder_{tabname}', visible=not shared.cmd_opts.hide_ui_dir_config) + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(shared.opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Accordion("Generation Info", open=False): + html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], + show_progress=False, + ) + + save.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ], + show_progress=False, + ) + + save_zip.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + paste_field_names = [] + if tabname == "txt2img": + paste_field_names = modules.scripts.scripts_txt2img.paste_field_names + elif tabname == "img2img": + paste_field_names = modules.scripts.scripts_img2img.paste_field_names + + for paste_tabname, paste_button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery, + paste_field_names=paste_field_names + )) + + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def 
create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + diff --git a/modules/ui_components.py b/modules/ui_components.py new file mode 100644 index 0000000000000000000000000000000000000000..8afdf076db8e7f633b3611cd23964bb69c52acc1 --- /dev/null +++ b/modules/ui_components.py @@ -0,0 +1,74 @@ +import gradio as gr + + +class FormComponent: + def get_expected_parent(self): + return gr.components.Form + + +gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent + + +class ToolButton(FormComponent, gr.Button): + """Small button with single emoji as text, fits inside gradio forms""" + + def __init__(self, *args, **kwargs): + classes = kwargs.pop("elem_classes", []) + super().__init__(*args, elem_classes=["tool", *classes], **kwargs) + + def get_block_name(self): + return "button" + + +class FormRow(FormComponent, gr.Row): + """Same as gr.Row but fits inside gradio forms""" + + def get_block_name(self): + return "row" + + +class FormColumn(FormComponent, gr.Column): + """Same as gr.Column but fits inside gradio forms""" + + def get_block_name(self): + return "column" + + +class FormGroup(FormComponent, gr.Group): + """Same as gr.Row but fits inside gradio forms""" + + def get_block_name(self): + return "group" + + +class FormHTML(FormComponent, gr.HTML): + """Same as gr.HTML but fits inside gradio forms""" + + def get_block_name(self): + return "html" + + +class FormColorPicker(FormComponent, gr.ColorPicker): + """Same as gr.ColorPicker but fits inside gradio forms""" + + def get_block_name(self): + return "colorpicker" + + +class DropdownMulti(FormComponent, gr.Dropdown): + """Same as gr.Dropdown but always multiselect""" + def __init__(self, **kwargs): + super().__init__(multiselect=True, **kwargs) + + def get_block_name(self): + return "dropdown" + + +class DropdownEditable(FormComponent, gr.Dropdown): + """Same as gr.Dropdown but allows editing value""" + def __init__(self, **kwargs): + super().__init__(allow_custom_value=True, **kwargs) + + def get_block_name(self): + return "dropdown" + diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..3e8c5f1d39b6748430d0dd911c09475bdfd614a8 --- /dev/null +++ b/modules/ui_extensions.py @@ -0,0 +1,651 @@ +import json +import os +import threading +import time +from datetime import datetime + +import git + +import gradio as gr +import html +import shutil +import errno + +from modules import extensions, shared, paths, config_states, errors, restart +from modules.paths_internal import config_states_dir +from modules.call_queue import wrap_gradio_gpu_call + +available_extensions = {"extensions": []} +STYLE_PRIMARY = ' style="color: var(--primary-400)"' + + +def check_access(): + assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags" + + +def apply_and_restart(disable_list, update_list, disable_all): + check_access() + + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + update = json.loads(update_list) + assert 
type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" + + if update: + save_config_state("Backup (pre-update)") + + update = set(update) + + for ext in extensions.extensions: + if ext.name not in update: + continue + + try: + ext.fetch_and_reset_hard() + except Exception: + errors.report(f"Error getting updates for {ext.name}", exc_info=True) + + shared.opts.disabled_extensions = disabled + shared.opts.disable_all_extensions = disable_all + shared.opts.save(shared.config_filename) + + if restart.is_restartable(): + restart.restart_program() + else: + restart.stop_program() + + +def save_config_state(name): + current_config_state = config_states.get_config() + if not name: + name = "Config" + current_config_state["name"] = name + timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S') + filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") + print(f"Saving backup of webui/extension state to {filename}.") + with open(filename, "w", encoding="utf-8") as f: + json.dump(current_config_state, f) + config_states.list_config_states() + new_value = next(iter(config_states.all_config_states.keys()), "Current") + new_choices = ["Current"] + list(config_states.all_config_states.keys()) + return gr.Dropdown.update(value=new_value, choices=new_choices), f"Saved current webui/extension state to \"{filename}\"" + + +def restore_config_state(confirmed, config_state_name, restore_type): + if config_state_name == "Current": + return "Select a config to restore from." + if not confirmed: + return "Cancelled." + + check_access() + + config_state = config_states.all_config_states[config_state_name] + + print(f"*** Restoring webui state from backup: {restore_type} ***") + + if restore_type == "extensions" or restore_type == "both": + shared.opts.restore_config_state_file = config_state["filepath"] + shared.opts.save(shared.config_filename) + + if restore_type == "webui" or restore_type == "both": + config_states.restore_webui_config(config_state) + + shared.state.request_restart() + + return "" + + +def check_updates(id_task, disable_list): + check_access() + + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] + shared.state.job_count = len(exts) + + for ext in exts: + shared.state.textinfo = ext.name + + try: + ext.check_updates() + except FileNotFoundError as e: + if 'FETCH_HEAD' not in str(e): + raise + except Exception: + errors.report(f"Error checking updates for {ext.name}", exc_info=True) + + shared.state.nextjob() + + return extension_table(), "" + + +def make_commit_link(commit_hash, remote, text=None): + if text is None: + text = commit_hash[:8] + if remote.startswith("https://github.com/"): + if remote.endswith(".git"): + remote = remote[:-4] + href = remote + "/commit/" + commit_hash + return f'<a href="{href}" target="_blank">{text}</a>' + else: + return text + + +def extension_table(): + code = f"""<!-- {time.time()} --> + <table id="extensions"> + <thead> + <tr> + <th>Extension</th> + <th>URL</th> + <th>Branch</th> + <th>Version</th> + <th>Date</th> + <th>Update</th> + </tr> + </thead> + <tbody> + """ + + for ext in extensions.extensions: + ext: extensions.Extension + ext.read_info_from_repo() + + remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>""" + + if ext.can_update: + ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>""" + else: + ext_status = ext.status + + style = "" + if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all": + style = STYLE_PRIMARY + + version_link = ext.version + if ext.commit_hash and ext.remote: + version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version) + + code += f""" + <tr> + <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td> + <td>{remote}</td> + <td>{ext.branch}</td> + <td>{version_link}</td> + <td>{time.asctime(time.gmtime(ext.commit_date))}</td> + <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td> + </tr> + """ + + code += """ + </tbody> + </table>
+ """ + + return code + + +def update_config_states_table(state_name): + if state_name == "Current": + config_state = config_states.get_config() + else: + config_state = config_states.all_config_states[state_name] + + config_name = config_state.get("name", "Config") + created_date = time.asctime(time.gmtime(config_state["created_at"])) + filepath = config_state.get("filepath", "") + + code = f"""""" + + webui_remote = config_state["webui"]["remote"] or "" + webui_branch = config_state["webui"]["branch"] + webui_commit_hash = config_state["webui"]["commit_hash"] or "" + webui_commit_date = config_state["webui"]["commit_date"] + if webui_commit_date: + webui_commit_date = time.asctime(time.gmtime(webui_commit_date)) + else: + webui_commit_date = "" + + remote = f"""{html.escape(webui_remote or '')}""" + commit_link = make_commit_link(webui_commit_hash, webui_remote) + date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date) + + current_webui = config_states.get_webui_config() + + style_remote = "" + style_branch = "" + style_commit = "" + if current_webui["remote"] != webui_remote: + style_remote = STYLE_PRIMARY + if current_webui["branch"] != webui_branch: + style_branch = STYLE_PRIMARY + if current_webui["commit_hash"] != webui_commit_hash: + style_commit = STYLE_PRIMARY + + code += f"""

<h2>Config Backup: {config_name}</h2> + <div><b>Filepath:</b> {filepath}</div> + <div><b>Created at:</b> {created_date}</div> + """ + + webui_remote = config_state["webui"]["remote"] or "" + webui_branch = config_state["webui"]["branch"] + webui_commit_hash = config_state["webui"]["commit_hash"] or "" + webui_commit_date = config_state["webui"]["commit_date"] + if webui_commit_date: + webui_commit_date = time.asctime(time.gmtime(webui_commit_date)) + else: + webui_commit_date = "" + + remote = f"""<a href="{html.escape(webui_remote or '')}" target="_blank">{html.escape(webui_remote or '')}</a>""" + commit_link = make_commit_link(webui_commit_hash, webui_remote) + date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date) + + current_webui = config_states.get_webui_config() + + style_remote = "" + style_branch = "" + style_commit = "" + if current_webui["remote"] != webui_remote: + style_remote = STYLE_PRIMARY + if current_webui["branch"] != webui_branch: + style_branch = STYLE_PRIMARY + if current_webui["commit_hash"] != webui_commit_hash: + style_commit = STYLE_PRIMARY + + code += f"""<h2>WebUI State</h2> + <table id="config_state_webui"> + <thead> + <tr> + <th>URL</th> + <th>Branch</th> + <th>Commit</th> + <th>Date</th> + </tr> + </thead> + <tbody> + <tr> + <td><label{style_remote}>{remote}</label></td> + <td><label{style_branch}>{webui_branch}</label></td> + <td><label{style_commit}>{commit_link}</label></td> + <td><label{style_commit}>{date_link}</label></td> + </tr> + </tbody> + </table>
+ """ + + code += """

Extension State

+ + + + + + + + + + + + """ + + ext_map = {ext.name: ext for ext in extensions.extensions} + + for ext_name, ext_conf in config_state["extensions"].items(): + ext_remote = ext_conf["remote"] or "" + ext_branch = ext_conf["branch"] or "" + ext_enabled = ext_conf["enabled"] + ext_commit_hash = ext_conf["commit_hash"] or "" + ext_commit_date = ext_conf["commit_date"] + if ext_commit_date: + ext_commit_date = time.asctime(time.gmtime(ext_commit_date)) + else: + ext_commit_date = "" + + remote = f"""{html.escape(ext_remote or '')}""" + commit_link = make_commit_link(ext_commit_hash, ext_remote) + date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date) + + style_enabled = "" + style_remote = "" + style_branch = "" + style_commit = "" + if ext_name in ext_map: + current_ext = ext_map[ext_name] + current_ext.read_info_from_repo() + if current_ext.enabled != ext_enabled: + style_enabled = STYLE_PRIMARY + if current_ext.remote != ext_remote: + style_remote = STYLE_PRIMARY + if current_ext.branch != ext_branch: + style_branch = STYLE_PRIMARY + if current_ext.commit_hash != ext_commit_hash: + style_commit = STYLE_PRIMARY + + code += f""" + + + + + + + + """ + + code += """ + +
ExtensionURLBranchCommitDate
{html.escape(ext_name)}{remote}{ext_branch}{commit_link}{date_link}
+ """ + + return code + + +def normalize_git_url(url): + if url is None: + return "" + + url = url.replace(".git", "") + return url + + +def install_extension_from_url(dirname, url, branch_name=None): + check_access() + + if isinstance(dirname, str): + dirname = dirname.strip() + if isinstance(url, str): + url = url.strip() + + assert url, 'No URL specified' + + if dirname is None or dirname == "": + *parts, last_part = url.split('/') + last_part = normalize_git_url(last_part) + + dirname = last_part + + target_dir = os.path.join(extensions.extensions_dir, dirname) + assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' + + normalized_url = normalize_git_url(url) + if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url): + raise Exception(f'Extension with this URL is already installed: {url}') + + tmpdir = os.path.join(paths.data_path, "tmp", dirname) + + try: + shutil.rmtree(tmpdir, True) + if not branch_name: + # if no branch is specified, use the default branch + with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo: + repo.remote().fetch() + for submodule in repo.submodules: + submodule.update() + else: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo: + repo.remote().fetch() + for submodule in repo.submodules: + submodule.update() + try: + os.rename(tmpdir, target_dir) + except OSError as err: + if err.errno == errno.EXDEV: + # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems + # Since we can't use a rename, do the slower but more versitile shutil.move() + shutil.move(tmpdir, target_dir) + else: + # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled. + raise err + + import launch + launch.run_extension_installer(target_dir) + + extensions.list_extensions() + return [extension_table(), html.escape(f"Installed into {target_dir}. 
Use Installed tab to restart.")] + finally: + shutil.rmtree(tmpdir, True) + + +def install_extension_from_index(url, hide_tags, sort_column, filter_text): + ext_table, message = install_extension_from_url(None, url) + + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) + + return code, ext_table, message, '' + + +def refresh_available_extensions(url, hide_tags, sort_column): + global available_extensions + + import urllib.request + with urllib.request.urlopen(url) as response: + text = response.read() + + available_extensions = json.loads(text) + + code, tags = refresh_available_extensions_from_data(hide_tags, sort_column) + + return url, code, gr.CheckboxGroup.update(choices=tags), '', '' + + +def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text): + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) + + return code, '' + + +def search_extensions(filter_text, hide_tags, sort_column): + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text) + + return code, '' + + +sort_ordering = [ + # (reverse, order_by_function) + (True, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('name', 'z')), + (True, lambda x: x.get('name', 'z')), + (False, lambda x: 'z'), + (True, lambda x: x.get('commit_time', '')), + (True, lambda x: x.get('created_at', '')), + (True, lambda x: x.get('stars', 0)), +] + + +def get_date(info: dict, key): + try: + return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d") + except (ValueError, TypeError): + return '' + + +def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): + extlist = available_extensions["extensions"] + installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + + tags = available_extensions.get("tags", {}) + tags_to_hide = set(hide_tags) + hidden = 0 + + code = f"""<!-- {time.time()} --> + <table id="available_extensions"> + <thead> + <tr> + <th>Extension</th> + <th>Description</th> + <th>Action</th> + </tr> + </thead> + <tbody> + """ + + sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0] + + for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): + name = ext.get("name", "noname") + stars = int(ext.get("stars", 0)) + added = ext.get('added', 'unknown') + update_time = get_date(ext, 'commit_time') + create_time = get_date(ext, 'created_at') + url = ext.get("url", None) + description = ext.get("description", "") + extension_tags = ext.get("tags", []) + + if url is None: + continue + + existing = installed_extension_urls.get(normalize_git_url(url), None) + extension_tags = extension_tags + ["installed"] if existing else extension_tags + + if any(x for x in extension_tags if x in tags_to_hide): + hidden += 1 + continue + + if filter_text and filter_text.strip(): + if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower(): + hidden += 1 + continue + + install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"installed" if existing else "install"}</button>""" + + tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags]) + + code += f""" + <tr> + <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td> + <td>{html.escape(description)}<p class="info"><span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></span></p></td> + <td>{install_code}</td> + </tr> + """ + + for tag in [x for x in extension_tags if x not in tags]: + tags[tag] = tag + + code += """ + </tbody> + </table> + """ + + if hidden > 0: + code += f"<p>Extension hidden: {hidden}</p>
" + + return code, list(tags) + + +def preload_extensions_git_metadata(): + for extension in extensions.extensions: + extension.read_info_from_repo() + + +def create_ui(): + import modules.ui + + config_states.list_config_states() + + threading.Thread(target=preload_extensions_git_metadata).start() + + with gr.Blocks(analytics_enabled=False) as ui: + with gr.Tabs(elem_id="tabs_extensions"): + with gr.TabItem("Installed", id="installed"): + + with gr.Row(elem_id="extensions_installed_top"): + apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit") + apply = gr.Button(value=apply_label, variant="primary") + check = gr.Button(value="Check for updates") + extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all") + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) + + html = "" + if shared.opts.disable_all_extensions != "none": + html = """ + + "Disable all extensions" was set, change it to "none" to load all extensions again + + """ + info = gr.HTML(html) + extensions_table = gr.HTML('Loading...') + ui.load(fn=extension_table, inputs=[], outputs=[extensions_table]) + + apply.click( + fn=apply_and_restart, + _js="extensions_apply", + inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all], + outputs=[], + ) + + check.click( + fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]), + _js="extensions_check", + inputs=[info, extensions_disabled_list], + outputs=[extensions_table, info], + ) + + with gr.TabItem("Available", id="available"): + with gr.Row(): + refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") + extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json") + available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False) + extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) + install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) + + with gr.Row(): + hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") + + with gr.Row(): + search_extensions_text = gr.Text(label="Search").style(container=False) + + install_result = gr.HTML() + available_extensions_table = gr.HTML() + + refresh_available_extensions_button.click( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]), + inputs=[available_extensions_index, hide_tags, sort_column], + outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result], + ) + + install_extension_button.click( + fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), + inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text], + outputs=[available_extensions_table, 
extensions_table, install_result], + ) + + search_extensions_text.change( + fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]), + inputs=[search_extensions_text, hide_tags, sort_column], + outputs=[available_extensions_table, install_result], + ) + + hide_tags.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags, sort_column, search_extensions_text], + outputs=[available_extensions_table, install_result] + ) + + sort_column.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags, sort_column, search_extensions_text], + outputs=[available_extensions_table, install_result] + ) + + with gr.TabItem("Install from URL", id="install_from_url"): + install_url = gr.Text(label="URL for extension's git repository") + install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch") + install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") + install_button = gr.Button(value="Install", variant="primary") + install_result = gr.HTML(elem_id="extension_install_result") + + install_button.click( + fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), + inputs=[install_dirname, install_url, install_branch], + outputs=[install_url, extensions_table, install_result], + ) + + with gr.TabItem("Backup/Restore"): + with gr.Row(elem_id="extensions_backup_top_row"): + config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys())) + modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states") + config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type") + config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore") + with gr.Row(elem_id="extensions_backup_top_row2"): + config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False) + config_save_button = gr.Button(value="Save Current Config") + + config_states_info = gr.HTML("") + config_states_table = gr.HTML("Loading...") + ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table]) + + config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info]) + + dummy_component = gr.Label(visible=False) + config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info]) + + config_states_list.change( + fn=update_config_states_table, + inputs=[config_states_list], + outputs=[config_states_table], + ) + + + return ui diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..da4a9e5a85fb2466cc94f2c459b123c634d104fd --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,496 @@ +import os.path +import urllib.parse +from pathlib import Path + +from modules import shared, ui_extra_networks_user_metadata, errors +from 
modules.images import read_info_from_image, save_image_with_geninfo +from modules.ui import up_down_symbol +import gradio as gr +import json +import html + +from modules.ui_components import ToolButton +from fastapi.exceptions import HTTPException + +from modules.generation_parameters_copypaste import image_from_url_text +from modules.ui_components import ToolButton + +extra_pages = [] +allowed_dirs = set() +refresh_symbol = '\U0001f504' # 🔄 +#clear_symbol = '\U0001F5D9' # 🗙 + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" + + extra_pages.append(page) + allowed_dirs.clear() + allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) + + +def fetch_file(filename: str = ""): + from starlette.responses import FileResponse + + if not os.path.isfile(filename): + raise HTTPException(status_code=404, detail="File not found") + + if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs): + raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + + ext = os.path.splitext(filename)[1].lower() + if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"): + raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") + + # would profit from returning 304 + return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) + + +def get_metadata(page: str = "", item: str = ""): + from starlette.responses import JSONResponse + + page = next(iter([x for x in extra_pages if x.name == page]), None) + if page is None: + return JSONResponse({}) + + metadata = page.metadata.get(item) + if metadata is None: + return JSONResponse({}) + + return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)}) + + +def get_single_card(page: str = "", tabname: str = "", name: str = ""): + from starlette.responses import JSONResponse + + page = next(iter([x for x in extra_pages if x.name == page]), None) + + try: + item = page.create_item(name, enable_filter=False) + page.items[name] = item + except Exception as e: + errors.display(e, "creating item for extra network") + item = page.items.get(name) + + page.read_user_metadata(item) + item_html = page.create_html_for_item(item, tabname) + + return JSONResponse({"html": item_html}) + + +def add_pages_to_demo(app): + app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) + app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"]) + app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) + + +def quote_js(s): + s = s.replace('\\', '\\\\') + s = s.replace('"', '\\"') + return f'"{s}"' + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.name = title.lower() + self.id_page = self.name.replace(" ", "_") + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + self.metadata = {} + self.items = {} + + def refresh(self): + pass + + def read_user_metadata(self, item): + filename = item.get("filename", None) + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + desc = 
metadata.get("description", None) + if desc is not None: + item["description"] = desc + + item["user_metadata"] = metadata + + def link_preview(self, filename): + quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) + mtime = os.path.getmtime(filename) + return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" + + def search_terms_from_path(self, filename, possible_directories=None): + abspath = os.path.abspath(filename) + + for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): + parentdir = os.path.abspath(parentdir) + if abspath.startswith(parentdir): + return abspath[len(parentdir):].replace('\\', '/') + + return "" + + + def create_html(self, tabname): + view = "cards" #shared.opts.extra_networks_default_view + items_html = '' + + self.metadata = {} + + subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: + for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): + for dirname in sorted(dirs, key=shared.natural_sort_key): + x = os.path.join(root, dirname) + + if not os.path.isdir(x): + continue + + subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") + while subdir.startswith("/"): + subdir = subdir[1:] + + is_empty = len(os.listdir(x)) == 0 + if not is_empty and not subdir.endswith("/"): + subdir = subdir + "/" + + if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories: + continue + + subdirs[subdir] = 1 + + if subdirs: + subdirs = {"": 1, **subdirs} + + +# + subdirs_html = "".join([f""" + +""" for subdir in subdirs]) + + self.items = {x["name"]: x for x in self.list_items()} + for item in self.items.values(): + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + if "user_metadata" not in item: + self.read_user_metadata(item) + + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) + items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) + + self_name_id = self.name.replace(" ", "_") + +# + res = f""" +
    +{subdirs_html} +
    +
    +{items_html} +
    +""" + + return res + + def create_item(self, name, index=None): + raise NotImplementedError() + + def list_items(self): + raise NotImplementedError() + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + """ + Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown. + """ + + preview = item.get("preview", None) + + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + + #height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + #width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + background_image = f'' if preview else '' + + metadata_button = "" + metadata = item.get("metadata") + if metadata: + metadata_button = f"" + + edit_button = f"
    " + + local_path = "" + filename = item.get("filename", "") + + for reldir in self.allowed_directories_for_previews(): + absdir = os.path.abspath(reldir) + if filename.startswith(absdir): + local_path = filename[len(absdir):] + + # if this is true, the item must not be shown in the default view, and must instead only be + # shown when searching for it + + if shared.opts.extra_networks_hidden_models == "Always": + search_only = False + else: + search_only = "/." in local_path or "\\." in local_path + + if search_only and shared.opts.extra_networks_hidden_models == "Never": + return "" + + sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip() + + args = { + #"background_image": background_image, + #"style": f"'display: none; {height}{width}'", + "preview_image": html.escape(preview) if preview else './file=html/card-no-preview.png', + "prompt": item.get("prompt", None), + "tabname": quote_js(tabname), + "local_preview": quote_js(item["local_preview"]), + "name": item["name"], + "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "card_clicked": onclick, + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', + "search_term": item.get("search_term", ""), + "metadata_button": metadata_button, + "edit_button": edit_button, + "search_only": " search_only" if search_only else "", + "sort_keys": sort_keys, + } + + return self.card_page.format(**args) + + def get_sort_keys(self, path): + """ + List of default keys used for sorting in the UI. + """ + pth = Path(path) + stat = pth.stat() + return { + "date_created": int(stat.st_ctime or 0), + "date_modified": int(stat.st_mtime or 0), + "name": pth.name.lower(), + } + + def find_preview(self, path): + """ + Find a preview PNG for a given path (without extension) and call link_preview on it. + """ + + preview_extensions = ["png", "jpg", "jpeg", "webp"] + if shared.opts.samples_format not in preview_extensions: + preview_extensions.append(shared.opts.samples_format) + + # file_name = os.path.basename(path) + # location = os.path.dirname(path) + # preview_path = location + "/preview/" + file_name + # potential_files = sum([[path + "." + ext, path + ".preview." + ext, preview_path + "." + ext, preview_path + ".preview." + ext] for ext in preview_extensions], []) + + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], []) + + for file in potential_files: + if os.path.isfile(file): + return self.link_preview(file) + + for file in potential_files: + if os.path.isfile(file): + return self.link_preview(file) + + return None + + def find_description(self, path): + """ + Find and read a description file for a given path (without extension). 
+ """ + for file in [f"{path}.txt", f"{path}.description.txt"]: + try: + with open(file, "r", encoding="utf-8", errors="replace") as f: + return f.read() + except OSError: + pass + return None + + def create_user_metadata_editor(self, ui, tabname): + return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self) + + +def initialize(): + extra_pages.clear() + + +def register_default_pages(): + from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion + from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks + from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints + register_page(ExtraNetworksPageTextualInversion()) + register_page(ExtraNetworksPageHypernetworks()) + register_page(ExtraNetworksPageCheckpoints()) + + +class ExtraNetworksUi: + def __init__(self): + self.pages = None + """gradio HTML components related to extra networks' pages""" + + self.page_contents = None + """HTML content of the above; empty initially, filled when extra pages have to be shown""" + + self.stored_extra_pages = None + + self.button_save_preview = None + self.preview_target_filename = None + + self.tabname = None + + +def pages_in_preferred_order(pages): + tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")] + + def tab_name_score(name): + name = name.lower() + for i, possible_match in enumerate(tab_order): + if possible_match in name: + return i + + return len(pages) + + tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)} + + return sorted(pages, key=lambda x: tab_scores[x.name]) + + +def create_ui(container, button, tabname): + ui = ExtraNetworksUi() + ui.pages = [] + ui.pages_contents = [] + ui.user_metadata_editors = [] + ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) + ui.tabname = tabname + + with gr.Accordion("Extra Networks", open=True): + + with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: + for page in ui.stored_extra_pages: + page_id = page.title.lower().replace(" ", "_") + with gr.Tab(page.title, id=page_id): + #elem_id = f"{tabname}_{page_id}_cards_html" + #page_elem = gr.HTML('Loading...', elem_id=elem_id) + page_elem = gr.HTML(page.create_html(ui.tabname)) + ui.pages.append(page_elem) + + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) + + editor = page.create_user_metadata_editor(ui, tabname) + editor.create_ui() + ui.user_metadata_editors.append(editor) + + gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) + #gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) + #gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder") + button_refresh = ToolButton(value=refresh_symbol, elem_id=tabname+"_extra_refresh") + + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + + + + def toggle_visibility(is_visible): + is_visible = not is_visible + + return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) + + def fill_tabs(is_empty): + """Creates HTML for extra networks' 
tabs when the extra networks button is clicked for the first time.""" + + if not ui.pages_contents: + refresh() + + if is_empty: + return True, *ui.pages_contents + + return True, *[gr.update() for _ in ui.pages_contents] + + state_visible = gr.State(value=False) + button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False) + + state_empty = gr.State(value=True) + button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False) + + def refresh(): + for pg in ui.stored_extra_pages: + pg.refresh() + + ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] + + return ui.pages_contents + + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + + return ui + + +def path_is_parent(parent_path, child_path): + parent_path = os.path.abspath(parent_path) + child_path = os.path.abspath(child_path) + + return child_path.startswith(parent_path) + + +def setup_ui(ui, gallery): + def save_preview(index, images, filename): + # this function is here for backwards compatibility and likely will be removed soon + + if len(images) == 0: + print("There is no image in gallery to save as a preview.") + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + index = int(index) + index = 0 if index < 0 else index + index = len(images) - 1 if index >= len(images) else index + + img_info = images[index if index >= 0 else 0] + image = image_from_url_text(img_info) + geninfo, items = read_info_from_image(image) + + is_allowed = False + for extra_page in ui.stored_extra_pages: + if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()): + is_allowed = True + break + + assert is_allowed, f'writing to {filename} is not allowed' + + save_image_with_geninfo(image, geninfo, filename) + + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + ui.button_save_preview.click( + fn=save_preview, + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + outputs=[*ui.pages] + ) + + for editor in ui.user_metadata_editors: + editor.setup_ui(gallery) + + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f41bdd71ad1cdddc3f9c847eaf6dfbf6b6c491 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints.py @@ -0,0 +1,35 @@ +import html +import os + +from modules import shared, ui_extra_networks, sd_models +from modules.ui_extra_networks import quote_js + + +class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Checkpoints') + + def refresh(self): + shared.refresh_checkpoints() + + def create_item(self, name, index=None): + checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) + path, ext = os.path.splitext(checkpoint.filename) + return { + "name": checkpoint.name_for_extra, + "filename": checkpoint.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', + "local_preview": f"{path}.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, + } + + def list_items(self): + for 
index, name in enumerate(sd_models.checkpoints_list): + yield self.create_item(name, index) + + def allowed_directories_for_previews(self): + return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] + diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 0000000000000000000000000000000000000000..b53db2bf876793a8df9e2d284c00520022960d7e --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,35 @@ +import os + +from modules import shared, ui_extra_networks +from modules.ui_extra_networks import quote_js + + +class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Hypernetworks') + + def refresh(self): + shared.reload_hypernetworks() + + def create_item(self, name, index=None): + full_path = shared.hypernetworks[name] + path, ext = os.path.splitext(full_path) + + return { + "name": name, + "filename": full_path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(path), + "prompt": quote_js(f""), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, + } + + def list_items(self): + for index, name in enumerate(shared.hypernetworks): + yield self.create_item(name, index) + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.hypernetwork_dir] + diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 0000000000000000000000000000000000000000..9aee6f4e73b89fd86b2bb2191232f535f1409ca3 --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,35 @@ +import os + +from modules import ui_extra_networks, sd_hijack, shared +from modules.ui_extra_networks import quote_js + + +class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Textual Inversion') + self.allow_negative_prompt = True + + def refresh(self): + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + + def create_item(self, name, index=None): + embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) + + path, ext = os.path.splitext(embedding.filename) + return { + "name": name, + "filename": embedding.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(embedding.filename), + "prompt": quote_js(embedding.name), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, + } + + def list_items(self): + for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): + yield self.create_item(name, index) + + def allowed_directories_for_previews(self): + return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..268249d95eb263d4928e8806143e8f8e18c83a98 --- /dev/null +++ b/modules/ui_extra_networks_user_metadata.py @@ -0,0 +1,195 @@ +import datetime +import html +import json +import os.path + +import gradio as gr + +from modules import generation_parameters_copypaste, images, sysinfo, errors + + +class UserMetadataEditor: + + def 
__init__(self, ui, tabname, page): + self.ui = ui + self.tabname = tabname + self.page = page + self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata" + + self.box = None + + self.edit_name_input = None + self.button_edit = None + + self.edit_name = None + self.edit_description = None + self.edit_notes = None + self.html_filedata = None + self.html_preview = None + self.html_status = None + + self.button_cancel = None + self.button_replace_preview = None + self.button_save = None + + def get_user_metadata(self, name): + item = self.page.items.get(name, {}) + + user_metadata = item.get('user_metadata', None) + if user_metadata is None: + user_metadata = {} + item['user_metadata'] = user_metadata + + return user_metadata + + def create_extra_default_items_in_left_column(self): + pass + + def create_default_editor_elems(self): + with gr.Row(): + with gr.Column(scale=2): + self.edit_name = gr.HTML(elem_classes="extra-network-name") + self.edit_description = gr.Textbox(label="Description", lines=4) + self.html_filedata = gr.HTML() + + self.create_extra_default_items_in_left_column() + + with gr.Column(scale=1, min_width=0): + self.html_preview = gr.HTML() + + def create_default_buttons(self): + + with gr.Row(elem_classes="edit-user-metadata-buttons"): + self.button_cancel = gr.Button('Cancel') + self.button_replace_preview = gr.Button('Replace preview', variant='primary') + self.button_save = gr.Button('Save', variant='primary') + + self.html_status = gr.HTML(elem_classes="edit-user-metadata-status") + + self.button_cancel.click(fn=None, _js="closePopup") + + def get_card_html(self, name): + item = self.page.items.get(name, {}) + + preview_url = item.get("preview", None) + + if not preview_url: + filename, _ = os.path.splitext(item["filename"]) + preview_url = self.page.find_preview(filename) + item["preview"] = preview_url + + if preview_url: + preview = f''' +
    <div class='card standalone-card-preview'> + <img src="{html.escape(preview_url)}" class="preview"> + </div>
    + ''' + else: + preview = "<div class='card standalone-card-preview'></div>
    " + + return preview + + def get_metadata_table(self, name): + item = self.page.items.get(name, {}) + try: + filename = item["filename"] + + stats = os.stat(filename) + params = [ + ('File size: ', sysinfo.pretty_bytes(stats.st_size)), + ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), + ] + + return params + except Exception as e: + errors.display(e, f"reading info for {name}") + return [] + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + + try: + params = self.get_metadata_table(name) + except Exception as e: + errors.display(e, f"reading metadata info for {name}") + params = [] + + table = '' + "".join(f"" for name, value in params) + '' + + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '') + + def write_user_metadata(self, name, metadata): + item = self.page.items.get(name, {}) + filename = item.get("filename", None) + basename, ext = os.path.splitext(filename) + + with open(basename + '.json', "w", encoding="utf8") as file: + json.dump(metadata, file) + + def save_user_metadata(self, name, desc, notes): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + user_metadata["notes"] = notes + + self.write_user_metadata(name, user_metadata) + + def setup_save_handler(self, button, func, components): + button\ + .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\ + .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[]) + + def create_editor(self): + self.create_default_editor_elems() + + self.edit_notes = gr.TextArea(label='Notes', lines=4) + + self.create_default_buttons() + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes]) + + def create_ui(self): + with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: + self.box = box + + self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name") + self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button") + + self.create_editor() + + def save_preview(self, index, gallery, name): + if len(gallery) == 0: + return self.get_card_html(name), "There is no image in gallery to save as a preview." 
+ + item = self.page.items.get(name, {}) + + index = int(index) + index = 0 if index < 0 else index + index = len(gallery) - 1 if index >= len(gallery) else index + + img_info = gallery[index if index >= 0 else 0] + image = generation_parameters_copypaste.image_from_url_text(img_info) + geninfo, items = images.read_info_from_image(image) + + images.save_image_with_geninfo(image, geninfo, item["local_preview"]) + + return self.get_card_html(name), '' + + def setup_ui(self, gallery): + self.button_replace_preview.click( + fn=self.save_preview, + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", + inputs=[self.edit_name_input, gallery, self.edit_name_input], + outputs=[self.html_preview, self.html_status] + ).then( + fn=None, + _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", + inputs=[self.edit_name_input], + outputs=[] + ) + + + diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..4e761aa54828d44d9479fa626fa835c225c66b1c --- /dev/null +++ b/modules/ui_gradio_extensions.py @@ -0,0 +1,69 @@ +import os +import gradio as gr + +from modules import localization, shared, scripts +from modules.paths import script_path, data_path + + +def webpath(fn): + if fn.startswith(script_path): + web_path = os.path.relpath(fn, script_path).replace('\\', '/') + else: + web_path = os.path.abspath(fn) + + return f'file={web_path}?{os.path.getmtime(fn)}' + + +def javascript_html(): + # Ensure localization is in `window` before scripts + head = f'\n' + + script_js = os.path.join(script_path, "script.js") + head += f'\n' + + for script in scripts.list_scripts("javascript", ".js"): + head += f'\n' + + for script in scripts.list_scripts("javascript", ".mjs"): + head += f'\n' + + if shared.cmd_opts.theme: + head += f'\n' + + return head + + +def css_html(): + head = "" + + def stylesheet(fn): + return f'' + + for cssfile in scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + head += stylesheet(cssfile) + + if os.path.exists(os.path.join(data_path, "user.css")): + head += stylesheet(os.path.join(data_path, "user.css")) + + return head + + +def reload_javascript(): + js = javascript_html() + css = css_html() + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) + res.body = res.body.replace(b'', f'{css}'.encode("utf8")) + res.init_headers() + return res + + gr.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py new file mode 100644 index 0000000000000000000000000000000000000000..1e21eb595a1fd76fee29af682c99ab5531a87823 --- /dev/null +++ b/modules/ui_loadsave.py @@ -0,0 +1,210 @@ +import json +import os + +import gradio as gr + +from modules import errors +from modules.ui_components import ToolButton + + +class UiLoadsave: + """allows saving and restorig default values for gradio components""" + + def __init__(self, filename): + self.filename = filename + self.ui_settings = {} + self.component_mapping = {} + self.error_loading = False + self.finalized_ui = False + + self.ui_defaults_view = None + self.ui_defaults_apply = None + self.ui_defaults_review = None + + try: + if 
os.path.exists(self.filename): + self.ui_settings = self.read_from_file() + except Exception as e: + self.error_loading = True + errors.display(e, "loading settings") + + def add_component(self, path, x): + """adds component to the registry of tracked components""" + + assert not self.finalized_ui + + def apply_field(obj, field, condition=None, init_field=None): + key = f"{path}/{field}" + + if getattr(obj, 'custom_script_source', None) is not None: + key = f"customscript/{obj.custom_script_source}/{key}" + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = self.ui_settings.get(key, None) + if saved_value is None: + self.ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + pass + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if field == 'value' and key not in self.component_mapping: + self.component_mapping[key] = x + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + def check_dropdown(val): + if getattr(x, 'multiselect', False): + return all(value in x.choices for value in val) + else: + return val in x.choices + + apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) + + def check_tab_id(tab_id): + tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children)) + if type(tab_id) == str: + tab_ids = [t.id for t in tab_items] + return tab_id in tab_ids + elif type(tab_id) == int: + return 0 <= tab_id < len(tab_items) + else: + return False + + if type(x) == gr.Tabs: + apply_field(x, 'selected', check_tab_id) + + def add_block(self, x, path=""): + """adds all components inside a gradio block x to the registry of tracked components""" + + if hasattr(x, 'children'): + if isinstance(x, gr.Tabs) and x.elem_id is not None: + # Tabs element can't have a label, have to use elem_id instead + self.add_component(f"{path}/Tabs@{x.elem_id}", x) + for c in x.children: + self.add_block(c, path) + elif x.label is not None: + self.add_component(f"{path}/{x.label}", x) + elif isinstance(x, gr.Button) and x.value is not None: + self.add_component(f"{path}/{x.value}", x) + + def read_from_file(self): + with open(self.filename, "r", encoding="utf8") as file: + return json.load(file) + + def write_to_file(self, current_ui_settings): + with open(self.filename, "w", encoding="utf8") as file: + json.dump(current_ui_settings, file, indent=4) + + def dump_defaults(self): + """saves default values to a file unless tjhe file is present and there was an error loading default values at start""" + + if self.error_loading and os.path.exists(self.filename): + return + + self.write_to_file(self.ui_settings) + + def iter_changes(self, current_ui_settings, values): + """ + given a dictionary with defaults from a file and current values from gradio elements, returns + an iterator over tuples of values that are not the same between the file and the current; + tuple contents are: path, old value, new value + """ + + for (path, component), new_value in 
zip(self.component_mapping.items(), values): + old_value = current_ui_settings.get(path) + + choices = getattr(component, 'choices', None) + if isinstance(new_value, int) and choices: + if new_value >= len(choices): + continue + + new_value = choices[new_value] + + if new_value == old_value: + continue + + if old_value is None and new_value == '' or new_value == []: + continue + + yield path, old_value, new_value + + def ui_view(self, *values): + text = [""] + + for path, old_value, new_value in self.iter_changes(self.read_from_file(), values): + if old_value is None: + old_value = "None" + + text.append(f"") + + if len(text) == 1: + text.append("") + + text.append("") + return "".join(text) + + def ui_apply(self, *values): + num_changed = 0 + + current_ui_settings = self.read_from_file() + + for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values): + num_changed += 1 + current_ui_settings[path] = new_value + + if num_changed == 0: + return "No changes." + + self.write_to_file(current_ui_settings) + + return f"Wrote {num_changed} changes." + + def create_ui(self): + """creates ui elements for editing defaults UI, without adding any logic to them""" + + gr.HTML( + f"This page allows you to change default values in UI elements on other tabs.
    " + f"Make your changes, press 'View changes' to review the changed default values,
    " + f"then press 'Apply' to write them to {self.filename}.
    " + f"New defaults will apply after you restart the UI.
    " + ) + + with gr.Row(): + self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary") + self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary") + + self.ui_defaults_review = gr.HTML("") + + def setup_ui(self): + """adds logic to elements created with create_ui; all add_block class must be made before this""" + + assert not self.finalized_ui + self.finalized_ui = True + + self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review]) + self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..c51314165a7de6166e2959bc97fe6cd3f9cd3020 --- /dev/null +++ b/modules/ui_postprocessing.py @@ -0,0 +1,57 @@ +import gradio as gr +from modules import scripts, shared, ui_common, postprocessing, call_queue +import modules.generation_parameters_copypaste as parameters_copypaste + +def create_ui(): + tab_index = gr.State(value=0) + gr.Row(elem_id="extras_2img_prompt_image", visible=False) + with gr.Row(): + with gr.Column(elem_id="extras_2img_results"): + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras_2img", shared.opts.outdir_extras_samples) + gr.Row(elem_id="extras_2img_splitter") + with gr.Column(variant='panel', elem_id="extras_2img_settings"): + submit = gr.Button('Upscale', elem_id="extras_generate", variant='primary') + with gr.Column(elem_id="extras_2img_settings_scroll"): + with gr.Accordion("Image Source", elem_id="extras_accordion", open=True): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single: + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch: + image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir: + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + script_inputs = scripts.scripts_postproc.setup_ui() + + tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) + tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) + tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) + + submit.click( + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + inputs=[ + tab_index, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + *script_inputs + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + 
fn=scripts.scripts_postproc.image_changed, + inputs=[], outputs=[] + ) diff --git a/modules/ui_settings.py b/modules/ui_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..3f04839f53665c01d616eb7113951c200b3e8a01 --- /dev/null +++ b/modules/ui_settings.py @@ -0,0 +1,296 @@ +import gradio as gr + +from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo +from modules.call_queue import wrap_gradio_call +from modules.shared import opts +from modules.ui_components import FormRow +from modules.ui_gradio_extensions import reload_javascript + + +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + + +def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {t} for key {key}') + + elem_id = f"setting_{key}" + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + +class UiSettings: + submit = None + result = None + interface = None + components = None + component_dict = None + dummy_component = None + quicksettings_list = None + quicksettings_names = None + text_settings = None + + def run_settings(self, *args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + if comp == self.dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.' 
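# (editor's note) run_settings above accepts a value only when opts.same_type
# says it matches the setting's default. The real check lives in
# modules.shared; a simplified sketch of the intent (assuming, as upstream
# does, that int is mapped to float so whole numbers pass for float settings):
#
#     def same_type(x, y):
#         if x is None or y is None:
#             return True
#         typemap = {int: float}
#         return typemap.get(type(x), type(x)) == typemap.get(type(y), type(y))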
+ + def run_settings_single(self, value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return get_value_for_setting(key), opts.dumpjson() + + def create_ui(self, loadsave, dummy_component): + self.components = [] + self.component_dict = {} + self.dummy_component = dummy_component + + shared.settings_components = self.component_dict + + script_callbacks.ui_settings_callback() + opts.reorder() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + self.result = gr.HTML(elem_id="settings_result") + + self.quicksettings_names = opts.quicksettings_list + self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'} + + self.quicksettings_list = [] + + previous_section = None + current_tab = None + current_row = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + gr.Group() + current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) + current_tab.__enter__() + current_row = gr.Column(variant='compact') + current_row.__enter__() + + previous_section = item.section + + if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings: + self.quicksettings_list.append((i, k, item)) + self.components.append(dummy_component) + elif section_must_be_skipped: + self.components.append(dummy_component) + else: + component = create_setting_component(k) + self.component_dict[k] = component + self.components.append(component) + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): + loadsave.create_ui() + + with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"): + gr.HTML('Download system info
    (or open as text in a new page)', elem_id="sysinfo_download") + + with gr.Row(): + with gr.Column(scale=1): + sysinfo_check_file = gr.File(label="Check system info for validity", type='binary') + with gr.Column(scale=1): + sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity") + with gr.Column(scale=100): + pass + + with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + with gr.Row(): + unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + + with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + + unload_sd_model.click( + fn=sd_models.unload_model_weights, + inputs=[], + outputs=[] + ) + + reload_sd_model.click( + fn=sd_models.reload_model_weights, + inputs=[], + outputs=[] + ) + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + restart_gradio.click( + fn=shared.state.request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + def check_file(x): + if x is None: + return '' + + if sysinfo.check(x.decode('utf8', errors='ignore')): + return 'Valid' + + return 'Invalid' + + sysinfo_check_file.change( + fn=check_file, + inputs=[sysinfo_check_file], + outputs=[sysinfo_check_output], + ) + + self.interface = settings_interface + + def add_quicksettings(self): + with gr.Row(elem_id="quicksettings", variant="compact"): + for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + self.component_dict[k] = component + + def add_functionality(self, demo): + self.submit.click( + fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), + inputs=self.components, + outputs=[self.text_settings, self.result], + ) + + for _i, k, _item in self.quicksettings_list: + component = self.component_dict[k] + info = opts.data_labels[k] + + if isinstance(component, gr.Textbox): + methods = [component.submit, component.blur] + elif hasattr(component, 'release'): + methods = [component.release] + else: + methods = [component.change] + + for method in methods: + method( + fn=lambda value, k=k: self.run_settings_single(value, key=k), + inputs=[component], + outputs=[component, self.text_settings], + show_progress=info.refresh is not None, + ) + + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + 
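# (editor's note) Hidden-button bridge: the checkpoint cards' onclick handlers
# call selectCheckpoint(name) (see ui_extra_networks_checkpoints.py above),
# which on the frontend stores the title in the global desiredCheckpointName
# and programmatically clicks this invisible button; the _js function on the
# click handler below then consumes and clears that global, so the Python
# callback only ever receives a plain checkpoint name string.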
button_set_checkpoint.click( + fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component], + outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict] + + def get_settings_values(): + return [get_value_for_setting(key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[self.component_dict[k] for k in component_keys], + queue=False, + ) diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py new file mode 100644 index 0000000000000000000000000000000000000000..91d5919983bd6b35d988a0b2a1d6e3da2d1ba046 --- /dev/null +++ b/modules/ui_tempdir.py @@ -0,0 +1,85 @@ +import os +import tempfile +from collections import namedtuple +from pathlib import Path + +import gradio.components + +from PIL import PngImagePlugin + +from modules import shared + + +Savedfile = namedtuple("Savedfile", ["name"]) + + +def register_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): # gradio 3.15 + gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} + + if hasattr(gradio, 'temp_dirs'): # gradio 3.9 + gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} + + +def check_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): + return any(filename in fileset for fileset in gradio.temp_file_sets) + + if hasattr(gradio, 'temp_dirs'): + return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) + + return False + + +def save_pil_to_file(self, pil_image, dir=None, format="png"): + already_saved_as = getattr(pil_image, 'already_saved_as', None) + if already_saved_as and os.path.isfile(already_saved_as): + register_tmp_file(shared.demo, already_saved_as) + filename = already_saved_as + + if not shared.opts.save_images_add_number: + filename += f'?{os.path.getmtime(already_saved_as)}' + + return filename + + if shared.opts.temp_dir != "": + dir = shared.opts.temp_dir + + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj.name + + +# override save to file function so that it also writes PNG info +gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file + + +def on_tmpdir_changed(): + if shared.opts.temp_dir == "" or shared.demo is None: + return + + os.makedirs(shared.opts.temp_dir, exist_ok=True) + + register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) + + +def cleanup_tmpdr(): + temp_dir = shared.opts.temp_dir + if temp_dir == "" or not os.path.isdir(temp_dir): + return + + for root, _, files in os.walk(temp_dir, topdown=False): + for name in files: + _, extension = os.path.splitext(name) + if extension != ".png": + continue + + filename = os.path.join(root, name) + os.remove(filename) diff --git a/modules/upscaler.py b/modules/upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..e682bbaa26cd05fa8f00f6e6ca438a8c53f7d47b --- /dev/null +++ b/modules/upscaler.py @@ 
-0,0 +1,144 @@ +import os +from abc import abstractmethod + +import PIL +from PIL import Image + +import modules.shared +from modules import modelloader, shared + +LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) + + +class Upscaler: + name = None + model_path = None + model_name = None + model_url = None + enable = True + filter = None + model = None + user_path = None + scalers: [] + tile = True + + def __init__(self, create_dirs=False): + self.mod_pad_h = None + self.tile_size = modules.shared.opts.ESRGAN_tile + self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap + self.device = modules.shared.device + self.img = None + self.output = None + self.scale = 1 + self.half = not modules.shared.cmd_opts.no_half + self.pre_pad = 0 + self.mod_scale = None + self.model_download_path = None + + if self.model_path is None and self.name: + self.model_path = os.path.join(shared.models_path, self.name) + if self.model_path and create_dirs: + os.makedirs(self.model_path, exist_ok=True) + + try: + import cv2 # noqa: F401 + self.can_tile = True + except Exception: + pass + + @abstractmethod + def do_upscale(self, img: PIL.Image, selected_model: str): + return img + + def upscale(self, img: PIL.Image, scale, selected_model: str = None): + self.scale = scale + dest_w = int((img.width * scale) // 8 * 8) + dest_h = int((img.height * scale) // 8 * 8) + + for _ in range(3): + shape = (img.width, img.height) + + img = self.do_upscale(img, selected_model) + + if shape == (img.width, img.height): + break + + if img.width >= dest_w and img.height >= dest_h: + break + + if img.width != dest_w or img.height != dest_h: + img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) + + return img + + @abstractmethod + def load_model(self, path: str): + pass + + def find_models(self, ext_filter=None) -> list: + return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter) + + def update_status(self, prompt): + print(f"\nextras: {prompt}", file=shared.progress_print_out) + + +class UpscalerData: + name = None + data_path = None + scale: int = 4 + scaler: Upscaler = None + model: None + + def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): + self.name = name + self.data_path = path + self.local_data_path = path + self.scaler = upscaler + self.scale = scale + self.model = model + + +class UpscalerNone(Upscaler): + name = "None" + scalers = [] + + def load_model(self, path): + pass + + def do_upscale(self, img, selected_model=None): + return img + + def __init__(self, dirname=None): + super().__init__(False) + self.scalers = [UpscalerData("None", None, self)] + + +class UpscalerLanczos(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Lanczos" + self.scalers = [UpscalerData("Lanczos", None, self)] + + +class UpscalerNearest(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Nearest" + self.scalers = 
[UpscalerData("Nearest", None, self)] diff --git a/modules/xlmr.py b/modules/xlmr.py new file mode 100644 index 0000000000000000000000000000000000000000..a407a3cade8198bd8600bc7c9bbf8d778520a28c --- /dev/null +++ b/modules/xlmr.py @@ -0,0 +1,137 @@ +from transformers import BertPreTrainedModel, BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 768 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + self.pooler = lambda x: x[:,0] + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) 
+ features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # last module outputs + sequence_output = outputs[0] + + + # project every module + sequence_output_ln = self.pre_LN(sequence_output) + + # pooler + pooler_output = self.pooler(sequence_output_ln) + pooler_output = self.transformation(pooler_output) + projection_state = self.transformation(outputs.last_hidden_state) + + return { + 'pooler_output':pooler_output, + 'last_hidden_state':outputs.last_hidden_state, + 'hidden_states':outputs.hidden_states, + 'attentions':outputs.attentions, + 'projection_state':projection_state, + 'sequence_out': sequence_output + } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig diff --git a/outputs/txt2img-images/2023-07-30/00000-4104476258.png b/outputs/txt2img-images/2023-07-30/00000-4104476258.png new file mode 100644 index 0000000000000000000000000000000000000000..f49b5d64c94c7f822100b351bd95eb86d924132e Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00000-4104476258.png differ diff --git a/outputs/txt2img-images/2023-07-30/00001-1264812310.png b/outputs/txt2img-images/2023-07-30/00001-1264812310.png new file mode 100644 index 0000000000000000000000000000000000000000..8a8e58089098d6e74503d3f1f471f144fa2a861d Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00001-1264812310.png differ diff --git a/outputs/txt2img-images/2023-07-30/00002-629074369.png b/outputs/txt2img-images/2023-07-30/00002-629074369.png new file mode 100644 index 0000000000000000000000000000000000000000..3d502a505aabe229845ff00aabff76444374e449 Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00002-629074369.png differ diff --git a/outputs/txt2img-images/2023-07-30/00003-3929529382.png b/outputs/txt2img-images/2023-07-30/00003-3929529382.png new file mode 100644 index 0000000000000000000000000000000000000000..85a97f627363eb816082ff9cd4c161deb31ca7f3 Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00003-3929529382.png differ diff --git a/outputs/txt2img-images/2023-07-30/00004-2891905160.png b/outputs/txt2img-images/2023-07-30/00004-2891905160.png new file mode 100644 index 0000000000000000000000000000000000000000..66d1e1961810fc990655287e4ebf7754e0ab01fc Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00004-2891905160.png differ diff --git a/outputs/txt2img-images/2023-07-30/00005-1703927525.png 
b/outputs/txt2img-images/2023-07-30/00005-1703927525.png new file mode 100644 index 0000000000000000000000000000000000000000..591c3e5369a8e768e4ead0d134f6dc5c922fbc2e Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00005-1703927525.png differ diff --git a/outputs/txt2img-images/2023-07-30/00006-1703927525.png b/outputs/txt2img-images/2023-07-30/00006-1703927525.png new file mode 100644 index 0000000000000000000000000000000000000000..591c3e5369a8e768e4ead0d134f6dc5c922fbc2e Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00006-1703927525.png differ diff --git a/outputs/txt2img-images/2023-07-30/00007-1703927525.png b/outputs/txt2img-images/2023-07-30/00007-1703927525.png new file mode 100644 index 0000000000000000000000000000000000000000..591c3e5369a8e768e4ead0d134f6dc5c922fbc2e Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00007-1703927525.png differ diff --git a/outputs/txt2img-images/2023-07-30/00008-1703927525.png b/outputs/txt2img-images/2023-07-30/00008-1703927525.png new file mode 100644 index 0000000000000000000000000000000000000000..591c3e5369a8e768e4ead0d134f6dc5c922fbc2e Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00008-1703927525.png differ diff --git a/outputs/txt2img-images/2023-07-30/00009-1703927525.jpg b/outputs/txt2img-images/2023-07-30/00009-1703927525.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bcc96ac79ad29c6fefe9a84035312e5778ca7d07 Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00009-1703927525.jpg differ diff --git a/outputs/txt2img-images/2023-07-30/00010-210755578.jpg b/outputs/txt2img-images/2023-07-30/00010-210755578.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1149fed542fc7327a6576f75bc23c14aada1b371 Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00010-210755578.jpg differ diff --git a/outputs/txt2img-images/2023-07-30/00011-3978311133.jpg b/outputs/txt2img-images/2023-07-30/00011-3978311133.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a0c9f17bae2a672981f744718e6557321d8e7d16 Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00011-3978311133.jpg differ diff --git a/outputs/txt2img-images/2023-07-30/00012-3786155085.jpg b/outputs/txt2img-images/2023-07-30/00012-3786155085.jpg new file mode 100644 index 0000000000000000000000000000000000000000..94d80027304656a12174f9500bc2067935c4579c Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00012-3786155085.jpg differ diff --git a/outputs/txt2img-images/2023-07-30/00013-445379948.jpg b/outputs/txt2img-images/2023-07-30/00013-445379948.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87eee4874e220ac2d6419d12451e5916ce0d031b Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00013-445379948.jpg differ diff --git a/outputs/txt2img-images/2023-07-30/00014-3277595636.png b/outputs/txt2img-images/2023-07-30/00014-3277595636.png new file mode 100644 index 0000000000000000000000000000000000000000..599554288e13e29b5e46fa2370f60d3e3ed6085f Binary files /dev/null and b/outputs/txt2img-images/2023-07-30/00014-3277595636.png differ diff --git a/package.json b/package.json new file mode 100644 index 0000000000000000000000000000000000000000..c0ba406787db88b636d72767866274554f77381b --- /dev/null +++ b/package.json @@ -0,0 +1,11 @@ +{ + "name": "stable-diffusion-webui", + "version": "0.0.0", + "devDependencies": { + "eslint": "^8.40.0" + }, + "scripts": { + "lint": "eslint .", + "fix": "eslint 
--fix ." + } +} diff --git a/params.txt b/params.txt new file mode 100644 index 0000000000000000000000000000000000000000..4e21454c5a2b9bdee365e4743f4a21fb81997f05 --- /dev/null +++ b/params.txt @@ -0,0 +1,3 @@ +masterpiece, best quality, +Negative prompt: (worst quality, low quality:1.4) +Steps: 20, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 3277595636, Size: 512x768, Model hash: ee5e7d0285, Model: SCH_Excelsior, Version: 1.5.0 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..80541a8f35319e15d837ea8bdd3ffc4de25776ea --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[tool.ruff] + +target-version = "py39" + +extend-select = [ + "B", + "C", + "I", + "W", +] + +exclude = [ + "extensions", + "extensions-disabled", +] + +ignore = [ + "E501", # Line too long + "E731", # Do not assign a `lambda` expression, use a `def` + + "I001", # Import block is un-sorted or un-formatted + "C901", # Function is too complex + "C408", # Rewrite as a literal + "W605", # invalid escape sequence, messes with some docstrings +] + +[tool.ruff.per-file-ignores] +"webui.py" = ["E402"] # Module level import not at top of file + +[tool.ruff.flake8-bugbear] +# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. +extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] + +[tool.pytest.ini_options] +base_url = "http://127.0.0.1:7860" diff --git a/repositories/BLIP/BLIP.gif b/repositories/BLIP/BLIP.gif new file mode 100644 index 0000000000000000000000000000000000000000..f97959778a4d3a9c1d5c06793c96d96204fe2081 --- /dev/null +++ b/repositories/BLIP/BLIP.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7757a1a1133807158ec4e696a8187f289e64c30a86aa470d8e0a93948a02be22 +size 6707660 diff --git a/repositories/BLIP/CODEOWNERS b/repositories/BLIP/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..522fa4a0f715cd0328b9b9dbacae00e060193f43 --- /dev/null +++ b/repositories/BLIP/CODEOWNERS @@ -0,0 +1,2 @@ +# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. +#ECCN:Open Source diff --git a/repositories/BLIP/CODE_OF_CONDUCT.md b/repositories/BLIP/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..b6724718c9512d730bb7f1bcc5848cd420241407 --- /dev/null +++ b/repositories/BLIP/CODE_OF_CONDUCT.md @@ -0,0 +1,105 @@ +# Salesforce Open Source Community Code of Conduct + +## About the Code of Conduct + +Equality is a core value at Salesforce. We believe a diverse and inclusive +community fosters innovation and creativity, and are committed to building a +culture where everyone feels included. + +Salesforce open-source projects are committed to providing a friendly, safe, and +welcoming environment for all, regardless of gender identity and expression, +sexual orientation, disability, physical appearance, body size, ethnicity, nationality, +race, age, religion, level of experience, education, socioeconomic status, or +other similar personal characteristics. + +The goal of this code of conduct is to specify a baseline standard of behavior so +that people with different social values and communication styles can work +together effectively, productively, and respectfully in our open source community. +It also establishes a mechanism for reporting issues and resolving conflicts. 
+ +All questions and reports of abusive, harassing, or otherwise unacceptable behavior +in a Salesforce open-source project may be reported by contacting the Salesforce +Open Source Conduct Committee at ossconduct@salesforce.com. + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of gender +identity and expression, sexual orientation, disability, physical appearance, +body size, ethnicity, nationality, race, age, religion, level of experience, education, +socioeconomic status, or other similar personal characteristics. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy toward other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Personal attacks, insulting/derogatory comments, or trolling +* Public or private harassment +* Publishing, or threatening to publish, others' private information—such as +a physical or electronic address—without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting +* Advocating for or encouraging any of the above behaviors + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned with this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project email +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the Salesforce Open Source Conduct Committee +at ossconduct@salesforce.com. All complaints will be reviewed and investigated +and will result in a response that is deemed necessary and appropriate to the +circumstances. The committee is obligated to maintain confidentiality with +regard to the reporter of an incident. Further details of specific enforcement +policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership and the Salesforce Open Source Conduct +Committee. 
+ +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], +version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. +It includes adaptations and additions from [Go Community Code of Conduct][golang-coc], +[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. + +This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. + +[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) +[golang-coc]: https://golang.org/conduct +[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md +[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ +[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ diff --git a/repositories/BLIP/LICENSE.txt b/repositories/BLIP/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..a63e87f4e1e90c96861648a16a7304d97d3c3f7b --- /dev/null +++ b/repositories/BLIP/LICENSE.txt @@ -0,0 +1,12 @@ +Copyright (c) 2022, Salesforce.com, Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/repositories/BLIP/README.md b/repositories/BLIP/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7a86ebfc21536b97fde1789f0c9ec4a53d2bdd77 --- /dev/null +++ b/repositories/BLIP/README.md @@ -0,0 +1,114 @@ +## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation + + + +This is the PyTorch code of the BLIP paper [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10. +To install the dependencies, run
    pip install -r requirements.txt
+ +Catalog: +- [x] Inference demo +- [x] Pre-trained and finetuned checkpoints +- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2 +- [x] Pre-training code +- [x] Zero-shot video-text retrieval +- [x] Download of bootstrapped pre-training datasets + + +### Inference demo: +Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed). +The demo includes code for: +1. Image captioning +2. Open-ended visual question answering +3. Multimodal / unimodal feature extraction +4. Image-text matching + +Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). + +A Replicate web demo and Docker image are also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip) + +### Pre-trained checkpoints: +Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L +--- | :---: | :---: | :---: +14M | Download | - | - +129M | Download | Download | Download + +### Finetuned checkpoints: +Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L +--- | :---: | :---: | :---: +Image-Text Retrieval (COCO) | Download | - | Download +Image-Text Retrieval (Flickr30k) | Download | - | Download +Image Captioning (COCO) | - | Download | Download +VQA | Download | Download | - +NLVR2 | Download | - | - + + +### Image-Text Retrieval: +1. Download the COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly. +2. To evaluate the finetuned BLIP model on COCO, run: +
    python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
    --config ./configs/retrieval_coco.yaml \
    --output_dir output/retrieval_coco \
    --evaluate
+3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml to "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth" (a config-editing sketch follows the commands below). Then run: +
    python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
    --config ./configs/retrieval_coco.yaml \
    --output_dir output/retrieval_coco
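The config values above can be edited by hand or programmatically. A minimal config-editing sketch (not part of the BLIP repo; assumes PyYAML, and note that a safe_load/safe_dump round trip drops the comments in the file):

    import yaml  # PyYAML, an assumption; any YAML round-trip library works

    config_path = "configs/retrieval_coco.yaml"

    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    cfg["image_root"] = "/data/coco/images/"  # hypothetical local dataset path
    cfg["pretrained"] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"

    with open(config_path, "w") as f:
        yaml.safe_dump(cfg, f, default_flow_style=False)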
+ +### Image-Text Captioning: +1. Download the COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly. +2. To evaluate the finetuned BLIP model on COCO, run: +
    python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate
+3. To evaluate the finetuned BLIP model on NoCaps, generate results with the following command (evaluation needs to be performed on the official server): +
    python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py 
+4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml to "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run: +
    python -m torch.distributed.run --nproc_per_node=8 train_caption.py 
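Outside the distributed training scripts, single-image caption generation follows the Colab demo linked above. A minimal sketch, run from the repo root (the image path is hypothetical; the checkpoint URL, preprocessing constants, and generation settings are taken from the demo and configs/caption_coco.yaml, so treat the details as assumptions):

    import torch
    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    from models.blip import blip_decoder  # requires running from the BLIP repo root

    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_size = 384

    # CLIP-style normalization, as used in the demo notebook
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    image = transform(Image.open("demo.jpg").convert("RGB")).unsqueeze(0).to(device)  # hypothetical image

    model = blip_decoder(
        pretrained="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth",
        image_size=image_size,
        vit="base",
    )
    model.eval()
    model = model.to(device)

    with torch.no_grad():
        # beam search with the generation settings from configs/caption_coco.yaml
        captions = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
    print(captions[0])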
+ +### VQA: +1. Download the VQA v2 and Visual Genome datasets from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml. +2. To evaluate the finetuned BLIP model, generate results with the following command (evaluation needs to be performed on the official server): +
    python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml to "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run: +
    python -m torch.distributed.run --nproc_per_node=16 train_vqa.py 
+ +### NLVR2: +1. Download the NLVR2 dataset from the original website, and set 'image_root' in configs/nlvr.yaml. +2. To evaluate the finetuned BLIP model, run: +
    python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml to "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run: +
    python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py 
+ +### Finetune with ViT-L: +In order to finetune a model with ViT-L, simply change the config file to set 'vit' to 'large'. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). Gradient checkpointing can also be activated in the config file to reduce GPU memory usage. + +### Pre-train: +1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image} (a sketch for building such files follows the command below). +2. In configs/pretrain.yaml, set 'train_file' to the paths of the json files. +3. Pre-train the model using 8 A100 GPUs: +
    python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain 
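The annotation files from step 1 are plain json; a minimal sketch for producing one (the file name, paths, and captions are placeholders):

    import json

    # each record pairs one image path with one caption; the bootstrapped
    # datasets further below use 'url' in place of 'image'
    annotations = [
        {"image": "/data/images/000001.jpg", "caption": "a dog running on the beach"},
        {"image": "/data/images/000002.jpg", "caption": "two cups of coffee on a table"},
    ]

    with open("pretrain_data_0.json", "w") as f:  # hypothetical file name
        json.dump(annotations, f)

'train_file' in configs/pretrain.yaml then lists every json file produced this way.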
+ +### Zero-shot video-text retrieval: +1. Download the MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml. +2. Install [decord](https://github.com/dmlc/decord) (see the read-back sketch below) with
    pip install decord
+3. To perform zero-shot evaluation, run: +
    python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py
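As a quick check that decord can decode your videos, a minimal read-back sketch (the video path is hypothetical):

    from decord import VideoReader, cpu

    vr = VideoReader("/data/msrvtt/videos/video0.mp4", ctx=cpu(0))  # hypothetical path
    print(len(vr), "frames")

    frame = vr[0].asnumpy()  # first frame as an H x W x 3 uint8 array
    print(frame.shape)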
+ +### Pre-training datasets download: +We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}. A fetching sketch follows the table below. + +Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L +--- | :---: | :---: | :---: +CC3M+CC12M+SBU | Download | Download | Download +LAION115M | Download | Download | Download +
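To fetch images locally from one of these json files, a minimal sketch (the local filename and output directory are hypothetical; web urls go stale, so skipping failed downloads is expected):

    import json
    import os

    import requests  # assumed available; any HTTP client would do

    with open("ccs_synthetic_filtered.json") as f:  # hypothetical local copy of one json file
        records = json.load(f)

    os.makedirs("images", exist_ok=True)
    for i, rec in enumerate(records):
        try:
            resp = requests.get(rec["url"], timeout=10)
            resp.raise_for_status()
        except requests.RequestException:
            continue  # dead links are common in web-scraped data
        with open(os.path.join("images", f"{i:08}.jpg"), "wb") as out:
            out.write(resp.content)

### Citation
If you find this code to be useful for your research, please consider citing.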
    @inproceedings{li2022blip,
          title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
          author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
          year={2022},
          booktitle={ICML},
    }
+ +### Acknowledgement +The implementation of BLIP relies on resources from ALBEF, Huggingface Transformers, and timm. We thank the original authors for open-sourcing their work. diff --git a/repositories/BLIP/SECURITY.md b/repositories/BLIP/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..8249025739809035264e7776583b2f3ec100553c --- /dev/null +++ b/repositories/BLIP/SECURITY.md @@ -0,0 +1,7 @@ +## Security + +Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) +as soon as it is discovered. This library limits its runtime dependencies in +order to reduce the total cost of ownership as much as possible, but all consumers +should remain vigilant and have their security stakeholders review all third-party +products (3PP) like this one and their dependencies. diff --git a/repositories/BLIP/cog.yaml b/repositories/BLIP/cog.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1dfcc430a4cab0fdd2a60a682336219a61c4a4f --- /dev/null +++ b/repositories/BLIP/cog.yaml @@ -0,0 +1,17 @@ +build: + gpu: true + cuda: "11.1" + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "ipython==7.30.1" + - "torchvision==0.11.1" + - "torch==1.10.0" + - "timm==0.4.12" + - "transformers==4.15.0" + - "fairscale==0.4.4" + - "pycocoevalcap==1.2" + +predict: "predict.py:Predictor" diff --git a/repositories/BLIP/configs/bert_config.json b/repositories/BLIP/configs/bert_config.json new file mode 100644 index 0000000000000000000000000000000000000000..3ef38aabc7f966b53079e9d559dc59e459cc0051 --- /dev/null +++ b/repositories/BLIP/configs/bert_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30522, + "encoder_width": 768, + "add_cross_attention": true +} diff --git a/repositories/BLIP/configs/caption_coco.yaml b/repositories/BLIP/configs/caption_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42eab7030c0310ba2f265baf36fa1400aa6e5846 --- /dev/null +++ b/repositories/BLIP/configs/caption_coco.yaml @@ -0,0 +1,33 @@ +image_root: '/export/share/datasets/vision/coco/images/' +ann_root: 'annotation' +coco_gt_root: 'annotation/coco_gt' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' + +# size of vit model; base or large +vit: 'base' +vit_grad_ckpt: False +vit_ckpt_layer: 0 +batch_size: 32 +init_lr: 1e-5 + +# vit: 'large' +# vit_grad_ckpt: True +# vit_ckpt_layer: 5 +# batch_size: 16 +# init_lr: 2e-6 + +image_size: 384 + +# generation configs +max_length: 20 +min_length: 5 +num_beams: 3 +prompt: 'a picture of ' + +# optimizer +weight_decay: 0.05 +min_lr: 0 +max_epoch: 5 +
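The yaml configs above drive both training and generation. A minimal inspection sketch before launching a run (assumes PyYAML; BLIP's training scripts read the same files with their own yaml loader):

    import yaml

    with open("configs/caption_coco.yaml") as f:
        cfg = yaml.safe_load(f)

    # generation settings consumed when producing captions
    print(cfg["prompt"], cfg["num_beams"], cfg["max_length"], cfg["min_length"])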