"
+ embedding_yet_to_be_embedded = False
+
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ img_c = None
+
+ pbar = tqdm.tqdm(total=steps - initial_step)
+ try:
+ sd_hijack_checkpoint.add()
+
+ for i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+
+ if clip_grad:
+ clip_grad_sched.step(embedding.step)
+
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
+ c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
+
+ if use_weight:
+ loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
+ del w
+ else:
+ loss = shared.sd_model.forward(x, cond)[0] / gradient_step
+ del x
+
+ _loss_step += loss.item()
+ scaler.scale(loss).backward()
+
+ # go back until we reach gradient accumulation steps
+ if (j + 1) % gradient_step != 0:
+ continue
+
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
+ scaler.step(optimizer)
+ scaler.update()
+ embedding.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // steps_per_epoch
+ epoch_step = embedding.step % steps_per_epoch
+
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
+ "loss": f"{loss_step:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.assign_current_image(image)
+
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
+
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.shorthash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
+
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
+
+Loss: {loss_step:.7f}
+Step: {steps_done}
+Last prompt: {html.escape(batch.cond_text[0])}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
+"""
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr)
+ pass
+ finally:
+ pbar.leave = False
+ pbar.close()
+ shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
+
+ return embedding, filename
+
+
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.shorthash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
+ embedding.save(filename)
+ except:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b75f799e745fa693cda06763af80069324a964f
--- /dev/null
+++ b/modules/textual_inversion/ui.py
@@ -0,0 +1,45 @@
+import html
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import sd_hijack, shared
+
+
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
+
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+
+
+def preprocess(*args):
+ modules.textual_inversion.preprocess.preprocess(*args)
+
+ return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
+
+
+def train_embedding(*args):
+
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+
+ apply_optimizations = shared.opts.training_xattention_optimizations
+ try:
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+
+ embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
+Embedding saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/timer.py b/modules/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8187c28edea3d7ce30d1d8c086a6191eb49d960c
--- /dev/null
+++ b/modules/timer.py
@@ -0,0 +1,35 @@
+import time
+
+
+class Timer:
+ def __init__(self):
+ self.start = time.time()
+ self.records = {}
+ self.total = 0
+
+ def elapsed(self):
+ end = time.time()
+ res = end - self.start
+ self.start = end
+ return res
+
+ def record(self, category, extra_time=0):
+ e = self.elapsed()
+ if category not in self.records:
+ self.records[category] = 0
+
+ self.records[category] += e + extra_time
+ self.total += e + extra_time
+
+ def summary(self):
+ res = f"{self.total:.1f}s"
+
+ additions = [x for x in self.records.items() if x[1] >= 0.1]
+ if not additions:
+ return res
+
+ res += " ("
+ res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
+ res += ")"
+
+ return res
diff --git a/modules/txt2img.py b/modules/txt2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..3927d8538f06c1ed270c9a6cfd55d4bb15705ee5
--- /dev/null
+++ b/modules/txt2img.py
@@ -0,0 +1,69 @@
+import modules.scripts
+from modules import sd_samplers
+from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+ StableDiffusionProcessingImg2Img, process_images
+from modules.shared import opts, cmd_opts
+import modules.shared as shared
+import modules.processing as processing
+from modules.ui import plaintext_to_html
+
+
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
+ override_settings = create_override_settings_dict(override_settings_texts)
+
+ p = StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
+ outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
+ prompt=prompt,
+ styles=prompt_styles,
+ negative_prompt=negative_prompt,
+ seed=seed,
+ subseed=subseed,
+ subseed_strength=subseed_strength,
+ seed_resize_from_h=seed_resize_from_h,
+ seed_resize_from_w=seed_resize_from_w,
+ seed_enable_extras=seed_enable_extras,
+ sampler_name=sd_samplers.samplers[sampler_index].name,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ restore_faces=restore_faces,
+ tiling=tiling,
+ enable_hr=enable_hr,
+ denoising_strength=denoising_strength if enable_hr else None,
+ hr_scale=hr_scale,
+ hr_upscaler=hr_upscaler,
+ hr_second_pass_steps=hr_second_pass_steps,
+ hr_resize_x=hr_resize_x,
+ hr_resize_y=hr_resize_y,
+ override_settings=override_settings,
+ )
+
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
+
+ if cmd_opts.enable_console_prompts:
+ print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
+
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
+
+ if processed is None:
+ processed = process_images(p)
+
+ p.close()
+
+ shared.total_tqdm.clear()
+
+ generation_info_js = processed.js()
+ if opts.samples_log_stdout:
+ print(generation_info_js)
+
+ if opts.do_not_show_images:
+ processed.images = []
+
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/ui.py b/modules/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..badf4975128985ad55bb69d6ee028adc0a61f97c
--- /dev/null
+++ b/modules/ui.py
@@ -0,0 +1,1798 @@
+import html
+import json
+import math
+import mimetypes
+import os
+import platform
+import random
+import sys
+import tempfile
+import time
+import traceback
+from functools import partial, reduce
+import warnings
+
+import gradio as gr
+import gradio.routes
+import gradio.utils
+import numpy as np
+from PIL import Image, PngImagePlugin
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
+
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
+from modules.paths import script_path, data_path
+
+from modules.shared import opts, cmd_opts, restricted_opts
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.scripts
+import modules.shared as shared
+import modules.styles
+import modules.textual_inversion.ui
+from modules import prompt_parser
+from modules.images import save_image
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
+from modules.textual_inversion import textual_inversion
+import modules.hypernetworks.ui
+from modules.generation_parameters_copypaste import image_from_url_text
+import modules.extras
+
+warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+
+# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
+mimetypes.init()
+mimetypes.add_type('application/javascript', '.js')
+
+if not cmd_opts.share and not cmd_opts.listen:
+ # fix gradio phoning home
+ gradio.utils.version_check = lambda: None
+ gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+
+if cmd_opts.ngrok is not None:
+ import modules.ngrok as ngrok
+ print('ngrok authtoken detected, trying to connect...')
+ ngrok.connect(
+ cmd_opts.ngrok,
+ cmd_opts.port if cmd_opts.port is not None else 7860,
+ cmd_opts.ngrok_region
+ )
+
+
+def gr_show(visible=True):
+ return {"visible": visible, "__type__": "update"}
+
+
+sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
+sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
+
+css_hide_progressbar = """
+.wrap .m-12 svg { display:none!important; }
+.wrap .m-12::before { content:"Loading..." }
+.wrap .z-20 svg { display:none!important; }
+.wrap .z-20::before { content:"Loading..." }
+.wrap.cover-bg .z-20::before { content:"" }
+.progress-bar { display:none!important; }
+.meta-text { display:none!important; }
+.meta-text-center { display:none!important; }
+"""
+
+# Using constants for these since the variation selector isn't visible.
+# Important that they exactly match script.js for tooltip to work.
+random_symbol = '\U0001f3b2\ufe0f' # 🎲️
+reuse_symbol = '\u267b\ufe0f' # ♻️
+paste_symbol = '\u2199\ufe0f' # ↙
+refresh_symbol = '\U0001f504' # 🔄
+save_style_symbol = '\U0001f4be' # 💾
+apply_style_symbol = '\U0001f4cb' # 📋
+clear_prompt_symbol = '\U0001F5D1' # 🗑️
+extra_networks_symbol = '\U0001F3B4' # 🎴
+switch_values_symbol = '\U000021C5' # ⇅
+
+
+def plaintext_to_html(text):
+ return ui_common.plaintext_to_html(text)
+
+
+def send_gradio_gallery_to_image(x):
+ if len(x) == 0:
+ return None
+ return image_from_url_text(x[0])
+
+def visit(x, func, path=""):
+ if hasattr(x, 'children'):
+ for c in x.children:
+ visit(c, func, path)
+ elif x.label is not None:
+ func(path + "/" + str(x.label), x)
+
+
+def add_style(name: str, prompt: str, negative_prompt: str):
+ if name is None:
+ return [gr_show() for x in range(4)]
+
+ style = modules.styles.PromptStyle(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
+ # reserialize all styles every time we save them
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
+
+
+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+ from modules import processing, devices
+
+ if not enable:
+ return ""
+
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+ with devices.autocast():
+ p.init([""], [0], [0])
+
+ return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
+
+
+def apply_styles(prompt, prompt_neg, styles):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
+
+
+def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
+ if mode in {0, 1, 3, 4}:
+ return [interrogation_function(ii_singles[mode]), None]
+ elif mode == 2:
+ return [interrogation_function(ii_singles[mode]["image"]), None]
+ elif mode == 5:
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+ images = shared.listfiles(ii_input_dir)
+ print(f"Will process {len(images)} images.")
+ if ii_output_dir != "":
+ os.makedirs(ii_output_dir, exist_ok=True)
+ else:
+ ii_output_dir = ii_input_dir
+
+ for image in images:
+ img = Image.open(image)
+ filename = os.path.basename(image)
+ left, _ = os.path.splitext(filename)
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
+
+ return [gr.update(), None]
+
+
+def interrogate(image):
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
+ return gr.update() if prompt is None else prompt
+
+
+def interrogate_deepbooru(image):
+ prompt = deepbooru.model.tag(image)
+ return gr.update() if prompt is None else prompt
+
+
+def create_seed_inputs(target_interface):
+ with FormRow(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ seed.style(container=False)
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+
+ with gr.Group(elem_id=target_interface + '_subseed_show_box'):
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+
+ # Components to show/hide based on the 'Extra' checkbox
+ seed_extras = []
+
+ with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
+ seed_extras.append(seed_extra_row_1)
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed.style(container=False)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+
+ with FormRow(visible=False) as seed_extra_row_2:
+ seed_extras.append(seed_extra_row_2)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+
+ random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
+ random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+
+ def change_visibility(show):
+ return {comp: gr_show(show) for comp in seed_extras}
+
+ seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
+
+ return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+
+
+
+def connect_clear_prompt(button):
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+ button.click(
+ _js="clear_prompt",
+ fn=None,
+ inputs=[],
+ outputs=[],
+ )
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
+ """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
+ (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
+ was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+ def copy_seed(gen_info_string: str, index):
+ res = -1
+
+ try:
+ gen_info = json.loads(gen_info_string)
+ index -= gen_info.get('index_of_first_image', 0)
+
+ if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+ all_subseeds = gen_info.get('all_subseeds', [-1])
+ res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+ else:
+ all_seeds = gen_info.get('all_seeds', [-1])
+ res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+
+ except json.decoder.JSONDecodeError as e:
+ if gen_info_string != '':
+ print("Error parsing JSON generation info:", file=sys.stderr)
+ print(gen_info_string, file=sys.stderr)
+
+ return [res, gr_show(False)]
+
+ reuse_seed.click(
+ fn=copy_seed,
+ _js="(x, y) => [x, selected_gallery_index()]",
+ show_progress=False,
+ inputs=[generation_info, dummy_component],
+ outputs=[seed, dummy_component]
+ )
+
+
+def update_token_counter(text, steps):
+ try:
+ text, _ = extra_networks.parse_prompt(text)
+
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+ except Exception:
+ # a parsing error can happen here during typing, and we don't want to bother the user with
+ # messages related to it in console
+ prompt_schedules = [[[steps, text]]]
+
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+ return f"{token_count}/{max_length}"
+
+
+def create_toprow(is_img2img):
+ id_part = "img2img" if is_img2img else "txt2img"
+
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
+
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
+
+ button_interrogate = None
+ button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_id="interrogate_col"):
+ button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+ with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+ with gr.Row(elem_id=f"{id_part}_generate_box"):
+ interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+ skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ with gr.Row(elem_id=f"{id_part}_tools"):
+ paste = ToolButton(value=paste_symbol, elem_id="paste")
+ clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+ prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
+ save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
+
+ token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter")
+ negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
+
+ with gr.Row(elem_id=f"{id_part}_styles_row"):
+ prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
+ create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
+
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
+
+
+def setup_progressbar(*args, **kwargs):
+ pass
+
+
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
+ # dont allow model to be swapped when model hash exists in prompt
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+ return gr.update()
+
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return
+
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data.get(key, None)
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+ return getattr(opts, key)
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_output_panel(tabname, outdir):
+ return ui_common.create_output_panel(tabname, outdir)
+
+
+def create_sampler_and_steps_selection(choices, tabname):
+ if opts.samplers_in_dropdown:
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
+ sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ else:
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+ return steps, sampler_index
+
+
+def ordered_ui_categories():
+ user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+ yield category
+
+
+def get_value_for_setting(key):
+ value = getattr(opts, key)
+
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+ return gr.update(value=value, **args)
+
+
+def create_override_settings_dropdown(tabname, row):
+ dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
+
+ dropdown.change(
+ fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
+ inputs=[dropdown],
+ outputs=[dropdown],
+ )
+
+ return dropdown
+
+
+def create_ui():
+ import modules.img2img
+ import modules.txt2img
+
+ reload_javascript()
+
+ parameters_copypaste.reset()
+
+ modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+ modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
+
+ dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
+
+ with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='compact', elem_id="txt2img_settings"):
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "cfg":
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
+
+ elif category == "hires_fix":
+ with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "override_settings":
+ with FormRow(elem_id="txt2img_override_settings_row") as row:
+ override_settings = create_override_settings_dropdown('txt2img', row)
+
+ elif category == "scripts":
+ with FormGroup(elem_id="txt2img_script_container"):
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+ for input in hr_resolution_preview_inputs:
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
+
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+
+ txt2img_args = dict(
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
+ _js="submit",
+ inputs=[
+ dummy_component,
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ txt2img_prompt_styles,
+ steps,
+ sampler_index,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ enable_hr,
+ denoising_strength,
+ hr_scale,
+ hr_upscaler,
+ hr_second_pass_steps,
+ hr_resize_x,
+ hr_resize_y,
+ override_settings,
+ ] + custom_inputs,
+
+ outputs=[
+ txt2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ txt2img_prompt.submit(**txt2img_args)
+ submit.click(**txt2img_args)
+
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
+
+ txt_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ txt_prompt_img
+ ],
+ outputs=[
+ txt2img_prompt,
+ txt_prompt_img
+ ]
+ )
+
+ enable_hr.change(
+ fn=lambda x: gr_show(x),
+ inputs=[enable_hr],
+ outputs=[hr_options],
+ show_progress = False,
+ )
+
+ txt2img_paste_fields = [
+ (txt2img_prompt, "Prompt"),
+ (txt2img_negative_prompt, "Negative prompt"),
+ (steps, "Steps"),
+ (sampler_index, "Sampler"),
+ (restore_faces, "Face restoration"),
+ (cfg_scale, "CFG scale"),
+ (seed, "Seed"),
+ (width, "Size-1"),
+ (height, "Size-2"),
+ (batch_size, "Batch size"),
+ (subseed, "Variation seed"),
+ (subseed_strength, "Variation seed strength"),
+ (seed_resize_from_w, "Seed resize from-1"),
+ (seed_resize_from_h, "Seed resize from-2"),
+ (denoising_strength, "Denoising strength"),
+ (enable_hr, lambda d: "Denoising strength" in d),
+ (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (hr_scale, "Hires upscale"),
+ (hr_upscaler, "Hires upscaler"),
+ (hr_second_pass_steps, "Hires steps"),
+ (hr_resize_x, "Hires resize-1"),
+ (hr_resize_y, "Hires resize-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
+ ]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+ ))
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
+ ]
+
+ token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+
+ ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+
+ modules.scripts.scripts_current = modules.scripts.scripts_img2img
+ modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
+
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+
+ with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+
+ with FormRow().style(equal_height=False):
+ with gr.Column(variant='compact', elem_id="img2img_settings"):
+ copy_image_buttons = []
+ copy_image_destinations = {}
+
+ def add_copy_image_controls(tab_name, elem):
+ with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
+ gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
+
+ for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+ if name == tab_name:
+ gr.Button(title, interactive=False)
+ copy_image_destinations[name] = elem
+ continue
+
+ button = gr.Button(title)
+ copy_image_buttons.append((button, name, elem))
+
+ with gr.Tabs(elem_id="mode_img2img"):
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('img2img', init_img)
+
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('sketch', sketch)
+
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('inpaint', init_img_with_mask)
+
+ with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ inpaint_color_sketch_orig = gr.State(None)
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
+
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
+
+ with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
+
+ with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
+ hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(
+ f"Process images in a directory on the same machine where the server is running." +
+ f"
Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ f"
Add inpaint batch mask directory to enable inpaint batch processing."
+ f"{hidden}
"
+ )
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
+
+ return img
+
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js="switch_to_"+name.replace(" ", "_"),
+ inputs=[],
+ outputs=[],
+ )
+
+ with FormRow():
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="img2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "cfg":
+ with FormGroup():
+ with FormRow():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="img2img_checkboxes", variant="compact"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "override_settings":
+ with FormRow(elem_id="img2img_override_settings_row") as row:
+ override_settings = create_override_settings_dropdown('img2img', row)
+
+ elif category == "scripts":
+ with FormGroup(elem_id="img2img_script_container"):
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+
+ elif category == "inpaint":
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
+
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+
+ img2img_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ img2img_prompt_img
+ ],
+ outputs=[
+ img2img_prompt,
+ img2img_prompt_img
+ ]
+ )
+
+ img2img_args = dict(
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
+ _js="submit_img2img",
+ inputs=[
+ dummy_component,
+ dummy_component,
+ img2img_prompt,
+ img2img_negative_prompt,
+ img2img_prompt_styles,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ inpaint_color_sketch_orig,
+ init_img_inpaint,
+ init_mask_inpaint,
+ steps,
+ sampler_index,
+ mask_blur,
+ mask_alpha,
+ inpainting_fill,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ image_cfg_scale,
+ denoising_strength,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ resize_mode,
+ inpaint_full_res,
+ inpaint_full_res_padding,
+ inpainting_mask_invert,
+ img2img_batch_input_dir,
+ img2img_batch_output_dir,
+ img2img_batch_inpaint_mask_dir,
+ override_settings,
+ ] + custom_inputs,
+ outputs=[
+ img2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ interrogate_args = dict(
+ _js="get_img2img_tab_index",
+ inputs=[
+ dummy_component,
+ img2img_batch_input_dir,
+ img2img_batch_output_dir,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ init_img_inpaint,
+ ],
+ outputs=[img2img_prompt, dummy_component],
+ )
+
+ img2img_prompt.submit(**img2img_args)
+ submit.click(**img2img_args)
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
+
+ img2img_interrogate.click(
+ fn=lambda *args: process_interrogate(interrogate, *args),
+ **interrogate_args,
+ )
+
+ img2img_deepbooru.click(
+ fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
+ **interrogate_args,
+ )
+
+ prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
+ style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
+ style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
+
+ for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
+ button.click(
+ fn=add_style,
+ _js="ask_for_style_name",
+ # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
+ # the same number of parameters, but we only know the style-name after the JavaScript prompt
+ inputs=[dummy_component, prompt, negative_prompt],
+ outputs=[txt2img_prompt_styles, img2img_prompt_styles],
+ )
+
+ for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
+ button.click(
+ fn=apply_styles,
+ _js=js_func,
+ inputs=[prompt, negative_prompt, styles],
+ outputs=[prompt, negative_prompt, styles],
+ )
+
+ token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+
+ img2img_paste_fields = [
+ (img2img_prompt, "Prompt"),
+ (img2img_negative_prompt, "Negative prompt"),
+ (steps, "Steps"),
+ (sampler_index, "Sampler"),
+ (restore_faces, "Face restoration"),
+ (cfg_scale, "CFG scale"),
+ (image_cfg_scale, "Image CFG scale"),
+ (seed, "Seed"),
+ (width, "Size-1"),
+ (height, "Size-2"),
+ (batch_size, "Batch size"),
+ (subseed, "Variation seed"),
+ (subseed_strength, "Variation seed strength"),
+ (seed_resize_from_w, "Seed resize from-1"),
+ (seed_resize_from_h, "Seed resize from-2"),
+ (denoising_strength, "Denoising strength"),
+ (mask_blur, "Mask blur"),
+ *modules.scripts.scripts_img2img.infotext_fields
+ ]
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+ ))
+
+ modules.scripts.scripts_current = None
+
+ with gr.Blocks(analytics_enabled=False) as extras_interface:
+ ui_postprocessing.create_ui()
+
+ with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel'):
+ image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
+
+ with gr.Column(variant='panel'):
+ html = gr.HTML()
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
+ html2 = gr.HTML()
+ with gr.Row():
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
+
+ for tabname, button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
+ ))
+
+ image.change(
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
+ inputs=[image],
+ outputs=[html, generation_info, html2],
+ )
+
+ def update_interp_description(value):
+ interp_description_css = "{}
"
+ interp_descriptions = {
+ "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+ "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+ "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+ }
+ return interp_descriptions[value]
+
+ with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='compact'):
+ interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
+
+ with FormRow(elem_id="modelmerger_models"):
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
+ tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
+ create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+ custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+ interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
+
+ with FormRow():
+ checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+
+ with FormRow():
+ with gr.Column():
+ config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+ with gr.Column():
+ with FormRow():
+ bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+ create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
+
+ with FormRow():
+ discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
+ with gr.Row():
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+ with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+ with gr.Group(elem_id="modelmerger_results_panel"):
+ modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+
+ with gr.Blocks(analytics_enabled=False) as train_interface:
+ with gr.Row().style(equal_height=False):
+ gr.HTML(value="See wiki for detailed explanation.
")
+
+ with gr.Row(variant="compact").style(equal_height=False):
+ with gr.Tabs(elem_id="train_tabs"):
+
+ with gr.Tab(label="Create embedding"):
+ new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
+
+ with gr.Tab(label="Create hypernetwork"):
+ new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+ new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
+
+ with gr.Tab(label="Preprocess images"):
+ process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
+ process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
+ process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
+ process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
+
+ with gr.Row():
+ process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
+ process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
+ process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+ process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
+ process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
+
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
+
+ with gr.Row(visible=False) as process_focal_crop_row:
+ process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
+ process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
+ process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
+
+ with gr.Column(visible=False) as process_multicrop_col:
+ gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
+ with gr.Row():
+ process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
+ process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
+ with gr.Row():
+ process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
+ process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
+ with gr.Row():
+ process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
+ process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ with gr.Row():
+ interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
+ run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
+
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
+ process_focal_crop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_focal_crop],
+ outputs=[process_focal_crop_row],
+ )
+
+ process_multicrop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_multicrop],
+ outputs=[process_multicrop_col],
+ )
+
+ def get_textual_inversion_template_names():
+ return sorted([x for x in textual_inversion.textual_inversion_templates])
+
+ with gr.Tab(label="Train"):
+ gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
+ with FormRow():
+ train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
+
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+
+ with FormRow():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
+
+ with FormRow():
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
+
+ with FormRow():
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
+
+ with FormRow():
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
+
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
+
+ with FormRow():
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+
+ use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
+
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
+
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
+
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
+
+ with gr.Row():
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
+
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+ script_callbacks.ui_train_tabs_callback(params)
+
+ with gr.Column(elem_id='ti_gallery_container'):
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
+ ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
+
+ create_embedding.click(
+ fn=modules.textual_inversion.ui.create_embedding,
+ inputs=[
+ new_embedding_name,
+ initialization_text,
+ nvpt,
+ overwrite_old_embedding,
+ ],
+ outputs=[
+ train_embedding_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ create_hypernetwork.click(
+ fn=modules.hypernetworks.ui.create_hypernetwork,
+ inputs=[
+ new_hypernetwork_name,
+ new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
+ new_hypernetwork_layer_structure,
+ new_hypernetwork_activation_func,
+ new_hypernetwork_initialization_option,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout,
+ new_hypernetwork_dropout_structure
+ ],
+ outputs=[
+ train_hypernetwork_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ run_preprocess.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ dummy_component,
+ process_src,
+ process_dst,
+ process_width,
+ process_height,
+ preprocess_txt_action,
+ process_flip,
+ process_split,
+ process_caption,
+ process_caption_deepbooru,
+ process_split_threshold,
+ process_overlap_ratio,
+ process_focal_crop,
+ process_focal_crop_face_weight,
+ process_focal_crop_entropy_weight,
+ process_focal_crop_edges_weight,
+ process_focal_crop_debug,
+ process_multicrop,
+ process_multicrop_mindim,
+ process_multicrop_maxdim,
+ process_multicrop_minarea,
+ process_multicrop_maxarea,
+ process_multicrop_objective,
+ process_multicrop_threshold,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ],
+ )
+
+ train_embedding.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ dummy_component,
+ train_embedding_name,
+ embedding_learn_rate,
+ batch_size,
+ gradient_step,
+ dataset_directory,
+ log_directory,
+ training_width,
+ training_height,
+ varsize,
+ steps,
+ clip_grad_mode,
+ clip_grad_value,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
+ use_weight,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ save_image_with_stored_embedding,
+ preview_from_txt2img,
+ *txt2img_preview_params,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_hypernetwork.click(
+ fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ dummy_component,
+ train_hypernetwork_name,
+ hypernetwork_learn_rate,
+ batch_size,
+ gradient_step,
+ dataset_directory,
+ log_directory,
+ training_width,
+ training_height,
+ varsize,
+ steps,
+ clip_grad_mode,
+ clip_grad_value,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
+ use_weight,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ preview_from_txt2img,
+ *txt2img_preview_params,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ interrupt_training.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ interrupt_preprocessing.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ def create_setting_component(key, is_quicksettings=False):
+ def fun():
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
+
+ info = opts.data_labels[key]
+ t = type(info.default)
+
+ args = info.component_args() if callable(info.component_args) else info.component_args
+
+ if info.component is not None:
+ comp = info.component
+ elif t == str:
+ comp = gr.Textbox
+ elif t == int:
+ comp = gr.Number
+ elif t == bool:
+ comp = gr.Checkbox
+ else:
+ raise Exception(f'bad options item type: {str(t)} for key {key}')
+
+ elem_id = "setting_"+key
+
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ else:
+ with FormRow():
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ else:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+
+ return res
+
+ components = []
+ component_dict = {}
+ shared.settings_components = component_dict
+
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
+ def run_settings(*args):
+ changed = []
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ if comp == dummy_component:
+ continue
+
+ if opts.set(key, value):
+ changed.append(key)
+
+ try:
+ opts.save(shared.config_filename)
+ except RuntimeError:
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
+
+ def run_settings_single(value, key):
+ if not opts.same_type(value, opts.data_labels[key].default):
+ return gr.update(visible=True), opts.dumpjson()
+
+ if not opts.set(key, value):
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
+
+ opts.save(shared.config_filename)
+
+ return get_value_for_setting(key), opts.dumpjson()
+
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
+ with gr.Row():
+ with gr.Column(scale=6):
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
+
+ result = gr.HTML(elem_id="settings_result")
+
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
+
+ quicksettings_list = []
+
+ previous_section = None
+ current_tab = None
+ current_row = None
+ with gr.Tabs(elem_id="settings"):
+ for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
+
+ if previous_section != item.section and not section_must_be_skipped:
+ elem_id, text = item.section
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ gr.Group()
+ current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+ current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
+
+ previous_section = item.section
+
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
+ quicksettings_list.append((i, k, item))
+ components.append(dummy_component)
+ elif section_must_be_skipped:
+ components.append(dummy_component)
+ else:
+ component = create_setting_component(k)
+ component_dict[k] = component
+ components.append(component)
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ with gr.TabItem("Actions"):
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+
+ with gr.TabItem("Licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
+
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+ request_notifications.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='function(){}'
+ )
+
+ download_localization.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='download_localization'
+ )
+
+ def reload_scripts():
+ modules.scripts.reload_script_body_only()
+ reload_javascript() # need to refresh the html page
+
+ reload_script_bodies.click(
+ fn=reload_scripts,
+ inputs=[],
+ outputs=[]
+ )
+
+ def request_restart():
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+ restart_gradio.click(
+ fn=request_restart,
+ _js='restart_reload',
+ inputs=[],
+ outputs=[],
+ )
+
+ interfaces = [
+ (txt2img_interface, "txt2img", "txt2img"),
+ (img2img_interface, "img2img", "img2img"),
+ (extras_interface, "Extras", "extras"),
+ (pnginfo_interface, "PNG Info", "pnginfo"),
+ (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+ (train_interface, "Train", "ti"),
+ ]
+
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
+
+ if os.path.exists(os.path.join(data_path, "user.css")):
+ with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
+ css += file.read() + "\n"
+
+ if not cmd_opts.no_progressbar_hiding:
+ css += css_hide_progressbar
+
+ interfaces += script_callbacks.ui_tabs_callback()
+ interfaces += [(settings_interface, "Settings", "settings")]
+
+ extensions_interface = ui_extensions.create_ui()
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
+
+ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
+ with gr.Row(elem_id="quicksettings", variant="compact"):
+ for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
+ component = create_setting_component(k, is_quicksettings=True)
+ component_dict[k] = component
+
+ parameters_copypaste.connect_paste_params_buttons()
+
+ with gr.Tabs(elem_id="tabs") as tabs:
+ for interface, label, ifid in interfaces:
+ with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
+ interface.render()
+
+ if os.path.exists(os.path.join(script_path, "notification.mp3")):
+ audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+
+ footer = shared.html("footer.html")
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
+
+ text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+ settings_submit.click(
+ fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
+ inputs=components,
+ outputs=[text_settings, result],
+ )
+
+ for i, k, item in quicksettings_list:
+ component = component_dict[k]
+
+ component.change(
+ fn=lambda value, k=k: run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, text_settings],
+ )
+
+ text_settings.change(
+ fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
+ inputs=[],
+ outputs=[image_cfg_scale],
+ )
+
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+ button_set_checkpoint.click(
+ fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[component_dict['sd_model_checkpoint'], dummy_component],
+ outputs=[component_dict['sd_model_checkpoint'], text_settings],
+ )
+
+ component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+ def get_settings_values():
+ return [get_value_for_setting(key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[component_dict[k] for k in component_keys],
+ )
+
+ def modelmerger(*args):
+ try:
+ results = modules.extras.run_modelmerger(*args)
+ except Exception as e:
+ print("Error loading/saving model file:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ modules.sd_models.list_models() # to remove the potentially missing models from the list
+ return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
+ return results
+
+ modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
+ modelmerger_merge.click(
+ fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+ _js='modelmerger',
+ inputs=[
+ dummy_component,
+ primary_model_name,
+ secondary_model_name,
+ tertiary_model_name,
+ interp_method,
+ interp_amount,
+ save_as_half,
+ custom_name,
+ checkpoint_format,
+ config_source,
+ bake_in_vae,
+ discard_weights,
+ ],
+ outputs=[
+ primary_model_name,
+ secondary_model_name,
+ tertiary_model_name,
+ component_dict['sd_model_checkpoint'],
+ modelmerger_result,
+ ]
+ )
+
+ ui_config_file = cmd_opts.ui_config_file
+ ui_settings = {}
+ settings_count = len(ui_settings)
+ error_loading = False
+
+ try:
+ if os.path.exists(ui_config_file):
+ with open(ui_config_file, "r", encoding="utf8") as file:
+ ui_settings = json.load(file)
+ except Exception:
+ error_loading = True
+ print("Error loading settings:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def loadsave(path, x):
+ def apply_field(obj, field, condition=None, init_field=None):
+ key = path + "/" + field
+
+ if getattr(obj, 'custom_script_source', None) is not None:
+ key = 'customscript/' + obj.custom_script_source + '/' + key
+
+ if getattr(obj, 'do_not_save_to_config', False):
+ return
+
+ saved_value = ui_settings.get(key, None)
+ if saved_value is None:
+ ui_settings[key] = getattr(obj, field)
+ elif condition and not condition(saved_value):
+ pass
+
+ # this warning is generally not useful;
+ # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ else:
+ setattr(obj, field, saved_value)
+ if init_field is not None:
+ init_field(saved_value)
+
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
+ apply_field(x, 'visible')
+
+ if type(x) == gr.Slider:
+ apply_field(x, 'value')
+ apply_field(x, 'minimum')
+ apply_field(x, 'maximum')
+ apply_field(x, 'step')
+
+ if type(x) == gr.Radio:
+ apply_field(x, 'value', lambda val: val in x.choices)
+
+ if type(x) == gr.Checkbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Textbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Number:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Dropdown:
+ def check_dropdown(val):
+ if getattr(x, 'multiselect', False):
+ return all([value in x.choices for value in val])
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+
+ visit(txt2img_interface, loadsave, "txt2img")
+ visit(img2img_interface, loadsave, "img2img")
+ visit(extras_interface, loadsave, "extras")
+ visit(modelmerger_interface, loadsave, "modelmerger")
+ visit(train_interface, loadsave, "train")
+
+ if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
+ with open(ui_config_file, "w", encoding="utf8") as file:
+ json.dump(ui_settings, file, indent=4)
+
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ interp_description.value = update_interp_description(interp_method.value)
+
+ return demo
+
+
+def reload_javascript():
+ head = f'\n'
+
+ inline = f"{localization.localization_js(shared.opts.localization)};"
+ if cmd_opts.theme is not None:
+ inline += f"set_theme('{cmd_opts.theme}');"
+
+ for script in modules.scripts.list_scripts("javascript", ".js"):
+ head += f'\n'
+
+ head += f'\n'
+
+ def template_response(*args, **kwargs):
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
+ res.body = res.body.replace(b'', f'{head}'.encode("utf8"))
+ res.init_headers()
+ return res
+
+ gradio.routes.templates.TemplateResponse = template_response
+
+
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+ shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
+
+
+def versions_html():
+ import torch
+ import launch
+
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+ commit = launch.commit_hash()
+ short_commit = commit[0:8]
+
+ if shared.xformers_available:
+ import xformers
+ xformers_version = xformers.__version__
+ else:
+ xformers_version = "N/A"
+
+ return f"""
+python: {python_version}
+ •
+torch: {getattr(torch, '__long_version__',torch.__version__)}
+ •
+xformers: {xformers_version}
+ •
+gradio: {gr.__version__}
+ •
+commit: {short_commit}
+ •
+checkpoint: N/A
+"""
diff --git a/modules/ui_common.py b/modules/ui_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..21ebb0955eec9604d6d41c22eeb1541f70a82580
--- /dev/null
+++ b/modules/ui_common.py
@@ -0,0 +1,206 @@
+import json
+import html
+import os
+import platform
+import sys
+
+import gradio as gr
+import subprocess as sp
+
+from modules import call_queue, shared
+from modules.generation_parameters_copypaste import image_from_url_text
+import modules.images
+
+folder_symbol = '\U0001f4c2' # 📂
+
+
+def update_generation_info(generation_info, html_info, img_index):
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info, gr.update()
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info, gr.update()
+
+
+def plaintext_to_html(text):
+ text = "" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "
"
+ return text
+
+
+def save_files(js_data, images, do_make_zip, index):
+ import csv
+ filenames = []
+ fullfns = []
+
+ #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
+ class MyObject:
+ def __init__(self, d=None):
+ if d is not None:
+ for key, value in d.items():
+ setattr(self, key, value)
+
+ data = json.loads(js_data)
+
+ p = MyObject(data)
+ path = shared.opts.outdir_save
+ save_to_dirs = shared.opts.use_save_to_dirs_for_ui
+ extension: str = shared.opts.samples_format
+ start_index = 0
+
+ if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+
+ images = [images[index]]
+ start_index = index
+
+ os.makedirs(shared.opts.outdir_save, exist_ok=True)
+
+ with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
+ at_start = file.tell() == 0
+ writer = csv.writer(file)
+ if at_start:
+ writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
+
+ for image_index, filedata in enumerate(images, start_index):
+ image = image_from_url_text(filedata)
+
+ is_grid = image_index < p.index_of_first_image
+ i = 0 if is_grid else (image_index - p.index_of_first_image)
+
+ fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+
+ filename = os.path.relpath(fullfn, path)
+ filenames.append(filename)
+ fullfns.append(fullfn)
+ if txt_fullfn:
+ filenames.append(os.path.basename(txt_fullfn))
+ fullfns.append(txt_fullfn)
+
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+
+ # Make Zip
+ if do_make_zip:
+ zip_filepath = os.path.join(path, "images.zip")
+
+ from zipfile import ZipFile
+ with ZipFile(zip_filepath, "w") as zip_file:
+ for i in range(len(fullfns)):
+ with open(fullfns[i], mode="rb") as f:
+ zip_file.writestr(filenames[i], f.read())
+ fullfns.insert(0, zip_filepath)
+
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
+
+
+def create_output_panel(tabname, outdir):
+ from modules import shared
+ import modules.generation_parameters_copypaste as parameters_copypaste
+
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+ open_folder_button.click(
+ fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
+ )
+
+ save.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ save_zip.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ else:
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ for paste_tabname, paste_button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
+ ))
+
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_components.py b/modules/ui_components.py
new file mode 100644
index 0000000000000000000000000000000000000000..d239d3f70938942f625f5f49e9398fcde10016bf
--- /dev/null
+++ b/modules/ui_components.py
@@ -0,0 +1,58 @@
+import gradio as gr
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+class ToolButtonTop(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool-top", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+class FormRow(gr.Row, gr.components.FormComponent):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "row"
+
+
+class FormGroup(gr.Group, gr.components.FormComponent):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "group"
+
+
+class FormHTML(gr.HTML, gr.components.FormComponent):
+ """Same as gr.HTML but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "html"
+
+
+class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
+ """Same as gr.ColorPicker but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "colorpicker"
+
+
+class DropdownMulti(gr.Dropdown):
+ """Same as gr.Dropdown but always multiselect"""
+ def __init__(self, **kwargs):
+ super().__init__(multiselect=True, **kwargs)
+
+ def get_block_name(self):
+ return "dropdown"
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f395cef3a6e1e0ad28d1577c0208794b897335
--- /dev/null
+++ b/modules/ui_extensions.py
@@ -0,0 +1,354 @@
+import json
+import os.path
+import shutil
+import sys
+import time
+import traceback
+
+import git
+
+import gradio as gr
+import html
+import shutil
+import errno
+
+from modules import extensions, shared, paths
+from modules.call_queue import wrap_gradio_gpu_call
+
+available_extensions = {"extensions": []}
+
+
+def check_access():
+ assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
+
+
+def apply_and_restart(disable_list, update_list):
+ check_access()
+
+ disabled = json.loads(disable_list)
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+ update = json.loads(update_list)
+ assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+
+ update = set(update)
+
+ for ext in extensions.extensions:
+ if ext.name not in update:
+ continue
+
+ try:
+ ext.fetch_and_reset_hard()
+ except Exception:
+ print(f"Error getting updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+
+def check_updates(id_task, disable_list):
+ check_access()
+
+ disabled = json.loads(disable_list)
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+ exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
+ shared.state.job_count = len(exts)
+
+ for ext in exts:
+ shared.state.textinfo = ext.name
+
+ try:
+ ext.check_updates()
+ except Exception:
+ print(f"Error checking updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.state.nextjob()
+
+ return extension_table(), ""
+
+
+def extension_table():
+ code = f"""
+
+ """
+
+ return code
+
+
+def normalize_git_url(url):
+ if url is None:
+ return ""
+
+ url = url.replace(".git", "")
+ return url
+
+
+def install_extension_from_url(dirname, url):
+ check_access()
+
+ assert url, 'No URL specified'
+
+ if dirname is None or dirname == "":
+ *parts, last_part = url.split('/')
+ last_part = normalize_git_url(last_part)
+
+ dirname = last_part
+
+ target_dir = os.path.join(extensions.extensions_dir, dirname)
+ assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
+
+ normalized_url = normalize_git_url(url)
+ assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
+
+ tmpdir = os.path.join(paths.data_path, "tmp", dirname)
+
+ try:
+ shutil.rmtree(tmpdir, True)
+
+ repo = git.Repo.clone_from(url, tmpdir)
+ repo.remote().fetch()
+
+ try:
+ os.rename(tmpdir, target_dir)
+ except OSError as err:
+ # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
+ # Shouldn't cause any new issues at least but we probably want to handle it there too.
+ if err.errno == errno.EXDEV:
+ # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
+ # Since we can't use a rename, do the slower but more versitile shutil.move()
+ shutil.move(tmpdir, target_dir)
+ else:
+ # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
+ raise(err)
+
+ import launch
+ launch.run_extension_installer(target_dir)
+
+ extensions.list_extensions()
+ return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
+ finally:
+ shutil.rmtree(tmpdir, True)
+
+
+def install_extension_from_index(url, hide_tags, sort_column):
+ ext_table, message = install_extension_from_url(None, url)
+
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
+
+ return code, ext_table, message
+
+
+def refresh_available_extensions(url, hide_tags, sort_column):
+ global available_extensions
+
+ import urllib.request
+ with urllib.request.urlopen(url) as response:
+ text = response.read()
+
+ available_extensions = json.loads(text)
+
+ code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
+
+ return url, code, gr.CheckboxGroup.update(choices=tags), ''
+
+
+def refresh_available_extensions_for_tags(hide_tags, sort_column):
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
+
+ return code, ''
+
+
+sort_ordering = [
+ # (reverse, order_by_function)
+ (True, lambda x: x.get('added', 'z')),
+ (False, lambda x: x.get('added', 'z')),
+ (False, lambda x: x.get('name', 'z')),
+ (True, lambda x: x.get('name', 'z')),
+ (False, lambda x: 'z'),
+]
+
+
+def refresh_available_extensions_from_data(hide_tags, sort_column):
+ extlist = available_extensions["extensions"]
+ installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+
+ tags = available_extensions.get("tags", {})
+ tags_to_hide = set(hide_tags)
+ hidden = 0
+
+ code = f"""
+
+
+
+ Extension |
+ Description |
+ Action |
+
+
+
+ """
+
+ sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
+
+ for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
+ name = ext.get("name", "noname")
+ added = ext.get('added', 'unknown')
+ url = ext.get("url", None)
+ description = ext.get("description", "")
+ extension_tags = ext.get("tags", [])
+
+ if url is None:
+ continue
+
+ existing = installed_extension_urls.get(normalize_git_url(url), None)
+ extension_tags = extension_tags + ["installed"] if existing else extension_tags
+
+ if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+ hidden += 1
+ continue
+
+ install_code = f""""""
+
+ tags_text = ", ".join([f"{x}" for x in extension_tags])
+
+ code += f"""
+
+ {html.escape(name)} {tags_text} |
+ {html.escape(description)} Added: {html.escape(added)} |
+ {install_code} |
+
+
+ """
+
+ for tag in [x for x in extension_tags if x not in tags]:
+ tags[tag] = tag
+
+ code += """
+
+
+ """
+
+ if hidden > 0:
+ code += f"Extension hidden: {hidden}
"
+
+ return code, list(tags)
+
+
+def create_ui():
+ import modules.ui
+
+ with gr.Blocks(analytics_enabled=False) as ui:
+ with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.TabItem("Installed"):
+
+ with gr.Row(elem_id="extensions_installed_top"):
+ apply = gr.Button(value="Apply and restart UI", variant="primary")
+ check = gr.Button(value="Check for updates")
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+
+ info = gr.HTML()
+ extensions_table = gr.HTML(lambda: extension_table())
+
+ apply.click(
+ fn=apply_and_restart,
+ _js="extensions_apply",
+ inputs=[extensions_disabled_list, extensions_update_list],
+ outputs=[],
+ )
+
+ check.click(
+ fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
+ _js="extensions_check",
+ inputs=[info, extensions_disabled_list],
+ outputs=[extensions_table, info],
+ )
+
+ with gr.TabItem("Available"):
+ with gr.Row():
+ refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
+ install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+
+ with gr.Row():
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
+
+ install_result = gr.HTML()
+ available_extensions_table = gr.HTML()
+
+ refresh_available_extensions_button.click(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ inputs=[available_extensions_index, hide_tags, sort_column],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
+ )
+
+ install_extension_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[extension_to_install, hide_tags, sort_column],
+ outputs=[available_extensions_table, extensions_table, install_result],
+ )
+
+ hide_tags.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags, sort_column],
+ outputs=[available_extensions_table, install_result]
+ )
+
+ sort_column.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags, sort_column],
+ outputs=[available_extensions_table, install_result]
+ )
+
+ with gr.TabItem("Install from URL"):
+ install_url = gr.Text(label="URL for extension's git repository")
+ install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
+ install_button = gr.Button(value="Install", variant="primary")
+ install_result = gr.HTML(elem_id="extension_install_result")
+
+ install_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ inputs=[install_dirname, install_url],
+ outputs=[extensions_table, install_result],
+ )
+
+ return ui
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e788319dad4aba4484437af166b7c0fccd651798
--- /dev/null
+++ b/modules/ui_extra_networks.py
@@ -0,0 +1,251 @@
+import glob
+import os.path
+import urllib.parse
+from pathlib import Path
+
+from modules import shared
+import gradio as gr
+import json
+import html
+
+from modules.generation_parameters_copypaste import image_from_url_text
+
+extra_pages = []
+allowed_dirs = set()
+
+
+def register_page(page):
+ """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
+
+ extra_pages.append(page)
+ allowed_dirs.clear()
+ allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
+
+
+def add_pages_to_demo(app):
+ def fetch_file(filename: str = ""):
+ from starlette.responses import FileResponse
+
+ if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+
+ ext = os.path.splitext(filename)[1].lower()
+ if ext not in (".png", ".jpg"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+
+ # would profit from returning 304
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+
+ app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
+
+
+class ExtraNetworksPage:
+ def __init__(self, title):
+ self.title = title
+ self.name = title.lower()
+ self.card_page = shared.html("extra-networks-card.html")
+ self.allow_negative_prompt = False
+
+ def refresh(self):
+ pass
+
+ def link_preview(self, filename):
+ return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
+
+ def search_terms_from_path(self, filename, possible_directories=None):
+ abspath = os.path.abspath(filename)
+
+ for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
+ parentdir = os.path.abspath(parentdir)
+ if abspath.startswith(parentdir):
+ return abspath[len(parentdir):].replace('\\', '/')
+
+ return ""
+
+ def create_html(self, tabname):
+ view = shared.opts.extra_networks_default_view
+ items_html = ''
+
+ subdirs = {}
+ for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
+ for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
+ if not os.path.isdir(x):
+ continue
+
+ subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
+ while subdir.startswith("/"):
+ subdir = subdir[1:]
+
+ is_empty = len(os.listdir(x)) == 0
+ if not is_empty and not subdir.endswith("/"):
+ subdir = subdir + "/"
+
+ subdirs[subdir] = 1
+
+ if subdirs:
+ subdirs = {"": 1, **subdirs}
+
+ subdirs_html = "".join([f"""
+
+""" for subdir in subdirs])
+
+ for item in self.list_items():
+ items_html += self.create_html_for_item(item, tabname)
+
+ if items_html == '':
+ dirs = "".join([f"{x}" for x in self.allowed_directories_for_previews()])
+ items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+ self_name_id = self.name.replace(" ", "_")
+
+ res = f"""
+
+
+"""
+
+ return res
+
+ def list_items(self):
+ raise NotImplementedError()
+
+ def allowed_directories_for_previews(self):
+ return []
+
+ def create_html_for_item(self, item, tabname):
+ preview = item.get("preview", None)
+
+ onclick = item.get("onclick", None)
+ if onclick is None:
+ onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+
+ args = {
+ "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
+ "prompt": item.get("prompt", None),
+ "tabname": json.dumps(tabname),
+ "local_preview": json.dumps(item["local_preview"]),
+ "name": item["name"],
+ "card_clicked": onclick,
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
+ "search_term": item.get("search_term", ""),
+ }
+
+ return self.card_page.format(**args)
+
+
+def intialize():
+ extra_pages.clear()
+
+
+class ExtraNetworksUi:
+ def __init__(self):
+ self.pages = None
+ self.stored_extra_pages = None
+
+ self.button_save_preview = None
+ self.preview_target_filename = None
+
+ self.tabname = None
+
+
+def pages_in_preferred_order(pages):
+ tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
+
+ def tab_name_score(name):
+ name = name.lower()
+ for i, possible_match in enumerate(tab_order):
+ if possible_match in name:
+ return i
+
+ return len(pages)
+
+ tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
+
+ return sorted(pages, key=lambda x: tab_scores[x.name])
+
+
+def create_ui(container, button, tabname):
+ ui = ExtraNetworksUi()
+ ui.pages = []
+ ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
+ ui.tabname = tabname
+
+ with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ for page in ui.stored_extra_pages:
+ with gr.Tab(page.title):
+ page_elem = gr.HTML(page.create_html(ui.tabname))
+ ui.pages.append(page_elem)
+
+ filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+ button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
+ ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
+ ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+
+ def toggle_visibility(is_visible):
+ is_visible = not is_visible
+ return is_visible, gr.update(visible=is_visible)
+
+ state_visible = gr.State(value=False)
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
+ button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
+
+ def refresh():
+ res = []
+
+ for pg in ui.stored_extra_pages:
+ pg.refresh()
+ res.append(pg.create_html(ui.tabname))
+
+ return res
+
+ button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
+
+ return ui
+
+
+def path_is_parent(parent_path, child_path):
+ parent_path = os.path.abspath(parent_path)
+ child_path = os.path.abspath(child_path)
+
+ return child_path.startswith(parent_path)
+
+
+def setup_ui(ui, gallery):
+ def save_preview(index, images, filename):
+ if len(images) == 0:
+ print("There is no image in gallery to save as a preview.")
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ index = int(index)
+ index = 0 if index < 0 else index
+ index = len(images) - 1 if index >= len(images) else index
+
+ img_info = images[index if index >= 0 else 0]
+ image = image_from_url_text(img_info)
+
+ is_allowed = False
+ for extra_page in ui.stored_extra_pages:
+ if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ is_allowed = True
+ break
+
+ assert is_allowed, f'writing to {filename} is not allowed'
+
+ image.save(filename)
+
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ ui.button_save_preview.click(
+ fn=save_preview,
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
+ inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
+ outputs=[*ui.pages]
+ )
+
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
new file mode 100644
index 0000000000000000000000000000000000000000..766e2c2499f0a1866e88424a319b99a9df973fc4
--- /dev/null
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -0,0 +1,39 @@
+import html
+import json
+import os
+import urllib.parse
+
+from modules import shared, ui_extra_networks, sd_models
+
+
+class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Checkpoints')
+
+ def refresh(self):
+ shared.refresh_checkpoints()
+
+ def list_items(self):
+ checkpoint: sd_models.CheckpointInfo
+ for name, checkpoint in sd_models.checkpoints_list.items():
+ path, ext = os.path.splitext(checkpoint.filename)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = self.link_preview(file)
+ break
+
+ yield {
+ "name": checkpoint.name_for_extra,
+ "filename": path,
+ "preview": preview,
+ "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
+
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fe6516a71e7ca5203fdacec3f750494d5650efd
--- /dev/null
+++ b/modules/ui_extra_networks_hypernets.py
@@ -0,0 +1,36 @@
+import json
+import os
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Hypernetworks')
+
+ def refresh(self):
+ shared.reload_hypernetworks()
+
+ def list_items(self):
+ for name, path in shared.hypernetworks.items():
+ path, ext = os.path.splitext(path)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = self.link_preview(file)
+ break
+
+ yield {
+ "name": name,
+ "filename": path,
+ "preview": preview,
+ "search_term": self.search_terms_from_path(path),
+ "prompt": json.dumps(f""),
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.hypernetwork_dir]
+
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..47491b2f09ccc960e2e237097f9d9e78075d25c0
--- /dev/null
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -0,0 +1,34 @@
+import json
+import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Textual Inversion')
+ self.allow_negative_prompt = True
+
+ def refresh(self):
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+ def list_items(self):
+ for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+ path, ext = os.path.splitext(embedding.filename)
+ preview_file = path + ".preview.png"
+
+ preview = None
+ if os.path.isfile(preview_file):
+ preview = self.link_preview(preview_file)
+
+ yield {
+ "name": embedding.name,
+ "filename": embedding.filename,
+ "preview": preview,
+ "search_term": self.search_terms_from_path(embedding.filename),
+ "prompt": json.dumps(embedding.name),
+ "local_preview": path + ".preview.png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..7789347028ecb309607038d0bc79eff934f45711
--- /dev/null
+++ b/modules/ui_postprocessing.py
@@ -0,0 +1,57 @@
+import gradio as gr
+from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+import modules.generation_parameters_copypaste as parameters_copypaste
+
+
+def create_ui():
+ tab_index = gr.State(value=0)
+
+ with gr.Row().style(equal_height=False, variant='compact'):
+ with gr.Column(variant='compact'):
+ with gr.Tabs(elem_id="mode_extras"):
+ with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
+
+ with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+
+ with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
+ script_inputs = scripts.scripts_postproc.setup_ui()
+
+ with gr.Column():
+ result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+
+ tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
+ tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
+ tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
+
+ submit.click(
+ fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+ inputs=[
+ tab_index,
+ extras_image,
+ image_batch,
+ extras_batch_input_dir,
+ extras_batch_output_dir,
+ show_extras_results,
+ *script_inputs
+ ],
+ outputs=[
+ result_images,
+ html_info_x,
+ html_info,
+ ]
+ )
+
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
+
+ extras_image.change(
+ fn=scripts.scripts_postproc.image_changed,
+ inputs=[], outputs=[]
+ )
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
new file mode 100644
index 0000000000000000000000000000000000000000..126f73a21d71070887fd094beaf0fe6d7e12df9c
--- /dev/null
+++ b/modules/ui_tempdir.py
@@ -0,0 +1,82 @@
+import os
+import tempfile
+from collections import namedtuple
+from pathlib import Path
+
+import gradio as gr
+
+from PIL import PngImagePlugin
+
+from modules import shared
+
+
+Savedfile = namedtuple("Savedfile", ["name"])
+
+
+def register_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
+
+ if hasattr(gradio, 'temp_dirs'): # gradio 3.9
+ gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
+
+
+def check_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'):
+ return any([filename in fileset for fileset in gradio.temp_file_sets])
+
+ if hasattr(gradio, 'temp_dirs'):
+ return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
+
+ return False
+
+
+def save_pil_to_file(pil_image, dir=None):
+ already_saved_as = getattr(pil_image, 'already_saved_as', None)
+ if already_saved_as and os.path.isfile(already_saved_as):
+ register_tmp_file(shared.demo, already_saved_as)
+
+ file_obj = Savedfile(already_saved_as)
+ return file_obj
+
+ if shared.opts.temp_dir != "":
+ dir = shared.opts.temp_dir
+
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in pil_image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
+ return file_obj
+
+
+# override save to file function so that it also writes PNG info
+gr.processing_utils.save_pil_to_file = save_pil_to_file
+
+
+def on_tmpdir_changed():
+ if shared.opts.temp_dir == "" or shared.demo is None:
+ return
+
+ os.makedirs(shared.opts.temp_dir, exist_ok=True)
+
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
+
+
+def cleanup_tmpdr():
+ temp_dir = shared.opts.temp_dir
+ if temp_dir == "" or not os.path.isdir(temp_dir):
+ return
+
+ for root, dirs, files in os.walk(temp_dir, topdown=False):
+ for name in files:
+ _, extension = os.path.splitext(name)
+ if extension != ".png":
+ continue
+
+ filename = os.path.join(root, name)
+ os.remove(filename)
diff --git a/modules/upscaler.py b/modules/upscaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2eaa7308af0091b6e8f407e889b2e446679e149
--- /dev/null
+++ b/modules/upscaler.py
@@ -0,0 +1,145 @@
+import os
+from abc import abstractmethod
+
+import PIL
+import numpy as np
+import torch
+from PIL import Image
+
+import modules.shared
+from modules import modelloader, shared
+
+LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
+
+
+class Upscaler:
+ name = None
+ model_path = None
+ model_name = None
+ model_url = None
+ enable = True
+ filter = None
+ model = None
+ user_path = None
+ scalers: []
+ tile = True
+
+ def __init__(self, create_dirs=False):
+ self.mod_pad_h = None
+ self.tile_size = modules.shared.opts.ESRGAN_tile
+ self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
+ self.device = modules.shared.device
+ self.img = None
+ self.output = None
+ self.scale = 1
+ self.half = not modules.shared.cmd_opts.no_half
+ self.pre_pad = 0
+ self.mod_scale = None
+
+ if self.model_path is None and self.name:
+ self.model_path = os.path.join(shared.models_path, self.name)
+ if self.model_path and create_dirs:
+ os.makedirs(self.model_path, exist_ok=True)
+
+ try:
+ import cv2
+ self.can_tile = True
+ except:
+ pass
+
+ @abstractmethod
+ def do_upscale(self, img: PIL.Image, selected_model: str):
+ return img
+
+ def upscale(self, img: PIL.Image, scale, selected_model: str = None):
+ self.scale = scale
+ dest_w = int(img.width * scale)
+ dest_h = int(img.height * scale)
+
+ for i in range(3):
+ shape = (img.width, img.height)
+
+ img = self.do_upscale(img, selected_model)
+
+ if shape == (img.width, img.height):
+ break
+
+ if img.width >= dest_w and img.height >= dest_h:
+ break
+
+ if img.width != dest_w or img.height != dest_h:
+ img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
+
+ return img
+
+ @abstractmethod
+ def load_model(self, path: str):
+ pass
+
+ def find_models(self, ext_filter=None) -> list:
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
+
+ def update_status(self, prompt):
+ print(f"\nextras: {prompt}", file=shared.progress_print_out)
+
+
+class UpscalerData:
+ name = None
+ data_path = None
+ scale: int = 4
+ scaler: Upscaler = None
+ model: None
+
+ def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
+ self.name = name
+ self.data_path = path
+ self.local_data_path = path
+ self.scaler = upscaler
+ self.scale = scale
+ self.model = model
+
+
+class UpscalerNone(Upscaler):
+ name = "None"
+ scalers = []
+
+ def load_model(self, path):
+ pass
+
+ def do_upscale(self, img, selected_model=None):
+ return img
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.scalers = [UpscalerData("None", None, self)]
+
+
+class UpscalerLanczos(Upscaler):
+ scalers = []
+
+ def do_upscale(self, img, selected_model=None):
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
+
+ def load_model(self, _):
+ pass
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.name = "Lanczos"
+ self.scalers = [UpscalerData("Lanczos", None, self)]
+
+
+class UpscalerNearest(Upscaler):
+ scalers = []
+
+ def do_upscale(self, img, selected_model=None):
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
+
+ def load_model(self, _):
+ pass
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.name = "Nearest"
+ self.scalers = [UpscalerData("Nearest", None, self)]
diff --git a/modules/xlmr.py b/modules/xlmr.py
new file mode 100644
index 0000000000000000000000000000000000000000..beab3fdf55e7bcffd96f3b36679e7a90c0f390dc
--- /dev/null
+++ b/modules/xlmr.py
@@ -0,0 +1,137 @@
+from transformers import BertPreTrainedModel,BertModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ config_class = BertSeriesConfig
+
+ def __init__(self, config=None, **kargs):
+ # modify initialization for autoloading
+ if config is None:
+ config = XLMRobertaConfig()
+ config.attention_probs_dropout_prob= 0.1
+ config.bos_token_id=0
+ config.eos_token_id=2
+ config.hidden_act='gelu'
+ config.hidden_dropout_prob=0.1
+ config.hidden_size=1024
+ config.initializer_range=0.02
+ config.intermediate_size=4096
+ config.layer_norm_eps=1e-05
+ config.max_position_embeddings=514
+
+ config.num_attention_heads=16
+ config.num_hidden_layers=24
+ config.output_past=True
+ config.pad_token_id=1
+ config.position_embedding_type= "absolute"
+
+ config.type_vocab_size= 1
+ config.use_cache=True
+ config.vocab_size= 250002
+ config.project_dim = 768
+ config.learn_encoder = False
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+ self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+ self.pooler = lambda x: x[:,0]
+ self.post_init()
+
+ def encode(self,c):
+ device = next(self.parameters()).device
+ text = self.tokenizer(c,
+ truncation=True,
+ max_length=77,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+ text["attention_mask"] = torch.tensor(
+ text['attention_mask']).to(device)
+ features = self(**text)
+ return features['projection_state']
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) :
+ r"""
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ outputs = self.roberta(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ # last module outputs
+ sequence_output = outputs[0]
+
+
+ # project every module
+ sequence_output_ln = self.pre_LN(sequence_output)
+
+ # pooler
+ pooler_output = self.pooler(sequence_output_ln)
+ pooler_output = self.transformation(pooler_output)
+ projection_state = self.transformation(outputs.last_hidden_state)
+
+ return {
+ 'pooler_output':pooler_output,
+ 'last_hidden_state':outputs.last_hidden_state,
+ 'hidden_states':outputs.hidden_states,
+ 'attentions':outputs.attentions,
+ 'projection_state':projection_state,
+ 'sequence_out': sequence_output
+ }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+ base_model_prefix = 'roberta'
+ config_class= RobertaSeriesConfig
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7e2e15b32f6801fa9bb9db46ff604670ad55ab6c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,30 @@
+blendmodes==2022
+transformers==4.25.1
+accelerate==0.12.0
+basicsr==1.4.2
+gfpgan==1.3.8
+gradio==3.16.2
+numpy==1.23.3
+Pillow==9.4.0
+realesrgan==0.3.0
+torch
+omegaconf==2.2.3
+pytorch_lightning==1.7.6
+scikit-image==0.19.2
+fonts
+font-roboto
+timm==0.6.7
+piexif==1.1.3
+einops==0.4.1
+jsonmerge==1.8.0
+clean-fid==0.1.29
+resize-right==0.0.2
+torchdiffeq==0.2.3
+kornia==0.6.7
+lark==1.1.2
+inflection==0.5.1
+GitPython==3.1.27
+torchsde==0.2.5
+safetensors==0.2.7
+httpcore<=0.15
+fastapi==0.90.1
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..935c544e3e8b9a9a282108563d4e00074502829a
--- /dev/null
+++ b/scripts/custom_code.py
@@ -0,0 +1,41 @@
+import modules.scripts as scripts
+import gradio as gr
+
+from modules.processing import Processed
+from modules.shared import opts, cmd_opts, state
+
+class Script(scripts.Script):
+
+ def title(self):
+ return "Custom code"
+
+ def show(self, is_img2img):
+ return cmd_opts.allow_code
+
+ def ui(self, is_img2img):
+ code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
+
+ return [code]
+
+
+ def run(self, p, code):
+ assert cmd_opts.allow_code, '--allow-code option must be enabled'
+
+ display_result_data = [[], -1, ""]
+
+ def display(imgs, s=display_result_data[1], i=display_result_data[2]):
+ display_result_data[0] = imgs
+ display_result_data[1] = s
+ display_result_data[2] = i
+
+ from types import ModuleType
+ compiled = compile(code, '', 'exec')
+ module = ModuleType("testmodule")
+ module.__dict__.update(globals())
+ module.p = p
+ module.display = display
+ exec(compiled, module.__dict__)
+
+ return Processed(p, *display_result_data)
+
+
\ No newline at end of file
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
new file mode 100644
index 0000000000000000000000000000000000000000..65b61533929a018f0cb97a89266154bf569cd40e
--- /dev/null
+++ b/scripts/img2imgalt.py
@@ -0,0 +1,216 @@
+from collections import namedtuple
+
+import numpy as np
+from tqdm import trange
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
+from modules.processing import Processed
+from modules.shared import opts, cmd_opts, state
+
+import torch
+import k_diffusion as K
+
+from PIL import Image
+from torch import autocast
+from einops import rearrange, repeat
+
+
+def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
+ x = p.init_latent
+
+ s_in = x.new_ones([x.shape[0]])
+ dnw = K.external.CompVisDenoiser(shared.sd_model)
+ sigmas = dnw.get_sigmas(steps).flip(0)
+
+ shared.state.sampling_steps = steps
+
+ for i in trange(1, len(sigmas)):
+ shared.state.sampling_step += 1
+
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigmas[i] * s_in] * 2)
+ cond_in = torch.cat([uncond, cond])
+
+ image_conditioning = torch.cat([p.image_conditioning] * 2)
+ cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
+ c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
+ t = dnw.sigma_to_t(sigma_in)
+
+ eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+ denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
+
+ denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
+
+ d = (x - denoised) / sigmas[i]
+ dt = sigmas[i] - sigmas[i - 1]
+
+ x = x + d * dt
+
+ sd_samplers_common.store_latent(x)
+
+ # This shouldn't be necessary, but solved some VRAM issues
+ del x_in, sigma_in, cond_in, c_out, c_in, t,
+ del eps, denoised_uncond, denoised_cond, denoised, d, dt
+
+ shared.state.nextjob()
+
+ return x / x.std()
+
+
+Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
+
+
+# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
+def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
+ x = p.init_latent
+
+ s_in = x.new_ones([x.shape[0]])
+ dnw = K.external.CompVisDenoiser(shared.sd_model)
+ sigmas = dnw.get_sigmas(steps).flip(0)
+
+ shared.state.sampling_steps = steps
+
+ for i in trange(1, len(sigmas)):
+ shared.state.sampling_step += 1
+
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
+ cond_in = torch.cat([uncond, cond])
+
+ image_conditioning = torch.cat([p.image_conditioning] * 2)
+ cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
+ c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
+
+ if i == 1:
+ t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
+ else:
+ t = dnw.sigma_to_t(sigma_in)
+
+ eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+ denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
+
+ denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
+
+ if i == 1:
+ d = (x - denoised) / (2 * sigmas[i])
+ else:
+ d = (x - denoised) / sigmas[i - 1]
+
+ dt = sigmas[i] - sigmas[i - 1]
+ x = x + d * dt
+
+ sd_samplers_common.store_latent(x)
+
+ # This shouldn't be necessary, but solved some VRAM issues
+ del x_in, sigma_in, cond_in, c_out, c_in, t,
+ del eps, denoised_uncond, denoised_cond, denoised, d, dt
+
+ shared.state.nextjob()
+
+ return x / sigmas[-1]
+
+
+class Script(scripts.Script):
+ def __init__(self):
+ self.cache = None
+
+ def title(self):
+ return "img2img alternative test"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ info = gr.Markdown('''
+ * `CFG Scale` should be 2 or lower.
+ ''')
+
+ override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler"))
+
+ override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt"))
+ original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt"))
+ original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt"))
+
+ override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps"))
+ st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st"))
+
+ override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength"))
+
+ cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg"))
+ randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness"))
+ sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
+
+ return [
+ info,
+ override_sampler,
+ override_prompt, original_prompt, original_negative_prompt,
+ override_steps, st,
+ override_strength,
+ cfg, randomness, sigma_adjustment,
+ ]
+
+ def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
+ # Override
+ if override_sampler:
+ p.sampler_name = "Euler"
+ if override_prompt:
+ p.prompt = original_prompt
+ p.negative_prompt = original_negative_prompt
+ if override_steps:
+ p.steps = st
+ if override_strength:
+ p.denoising_strength = 1.0
+
+ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+ lat = (p.init_latent.cpu().numpy() * 10).astype(int)
+
+ same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
+ and self.cache.original_prompt == original_prompt \
+ and self.cache.original_negative_prompt == original_negative_prompt \
+ and self.cache.sigma_adjustment == sigma_adjustment
+ same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
+
+ if same_everything:
+ rec_noise = self.cache.noise
+ else:
+ shared.state.job_count += 1
+ cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
+ uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
+ if sigma_adjustment:
+ rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
+ else:
+ rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
+ self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
+
+ rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
+
+ combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
+
+ sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
+
+ sigmas = sampler.model_wrap.get_sigmas(p.steps)
+
+ noise_dt = combined_noise - (p.init_latent / sigmas[0])
+
+ p.seed = p.seed + 1
+
+ return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
+
+ p.sample = sample_extra
+
+ p.extra_generation_params["Decode prompt"] = original_prompt
+ p.extra_generation_params["Decode negative prompt"] = original_negative_prompt
+ p.extra_generation_params["Decode CFG scale"] = cfg
+ p.extra_generation_params["Decode steps"] = st
+ p.extra_generation_params["Randomness"] = randomness
+ p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment
+
+ processed = processing.process_images(p)
+
+ return processed
+
diff --git a/scripts/loopback.py b/scripts/loopback.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ac98240ea022b028e056c21626b453ae19543d
--- /dev/null
+++ b/scripts/loopback.py
@@ -0,0 +1,98 @@
+import numpy as np
+from tqdm import trange
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import processing, shared, sd_samplers, images
+from modules.processing import Processed
+from modules.sd_samplers import samplers
+from modules.shared import opts, cmd_opts, state
+from modules import deepbooru
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Loopback"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
+ denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor"))
+ append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")
+
+ return [loops, denoising_strength_change_factor, append_interrogation]
+
+ def run(self, p, loops, denoising_strength_change_factor, append_interrogation):
+ processing.fix_seed(p)
+ batch_count = p.n_iter
+ p.extra_generation_params = {
+ "Denoising strength change factor": denoising_strength_change_factor,
+ }
+
+ p.batch_size = 1
+ p.n_iter = 1
+
+ output_images, info = None, None
+ initial_seed = None
+ initial_info = None
+
+ grids = []
+ all_images = []
+ original_init_image = p.init_images
+ original_prompt = p.prompt
+ state.job_count = loops * batch_count
+
+ initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
+
+ for n in range(batch_count):
+ history = []
+
+ # Reset to original init image at the start of each batch
+ p.init_images = original_init_image
+
+ for i in range(loops):
+ p.n_iter = 1
+ p.batch_size = 1
+ p.do_not_save_grid = True
+
+ if opts.img2img_color_correction:
+ p.color_corrections = initial_color_corrections
+
+ if append_interrogation != "None":
+ p.prompt = original_prompt + ", " if original_prompt != "" else ""
+ if append_interrogation == "CLIP":
+ p.prompt += shared.interrogator.interrogate(p.init_images[0])
+ elif append_interrogation == "DeepBooru":
+ p.prompt += deepbooru.model.tag(p.init_images[0])
+
+ state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}"
+
+ processed = processing.process_images(p)
+
+ if initial_seed is None:
+ initial_seed = processed.seed
+ initial_info = processed.info
+
+ init_img = processed.images[0]
+
+ p.init_images = [init_img]
+ p.seed = processed.seed + 1
+ p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
+ history.append(processed.images[0])
+
+ grid = images.image_grid(history, rows=1)
+ if opts.grid_save:
+ images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
+
+ grids.append(grid)
+ all_images += history
+
+ if opts.return_grid:
+ all_images = grids + all_images
+
+ processed = Processed(p, all_images, initial_seed, initial_info)
+
+ return processed
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d80b46cd3263ef0905514a761bb473441d8a1e7
--- /dev/null
+++ b/scripts/outpainting_mk_2.py
@@ -0,0 +1,283 @@
+import math
+
+import numpy as np
+import skimage
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image, ImageDraw
+
+from modules import images, processing, devices
+from modules.processing import Processed, process_images
+from modules.shared import opts, cmd_opts, state
+
+
+# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
+def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
+ # helper fft routines that keep ortho normalization and auto-shift before and after fft
+ def _fft2(data):
+ if data.ndim > 2: # has channels
+ out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:, :, c]
+ out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
+ out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
+ else: # one channel
+ out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
+ out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
+
+ return out_fft
+
+ def _ifft2(data):
+ if data.ndim > 2: # has channels
+ out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:, :, c]
+ out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
+ out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
+ else: # one channel
+ out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
+ out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
+
+ return out_ifft
+
+ def _get_gaussian_window(width, height, std=3.14, mode=0):
+ window_scale_x = float(width / min(width, height))
+ window_scale_y = float(height / min(width, height))
+
+ window = np.zeros((width, height))
+ x = (np.arange(width) / width * 2. - 1.) * window_scale_x
+ for y in range(height):
+ fy = (y / height * 2. - 1.) * window_scale_y
+ if mode == 0:
+ window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
+ else:
+ window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian
+
+ return window
+
+ def _get_masked_window_rgb(np_mask_grey, hardness=1.):
+ np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
+ if hardness != 1.:
+ hardened = np_mask_grey[:] ** hardness
+ else:
+ hardened = np_mask_grey[:]
+ for c in range(3):
+ np_mask_rgb[:, :, c] = hardened[:]
+ return np_mask_rgb
+
+ width = _np_src_image.shape[0]
+ height = _np_src_image.shape[1]
+ num_channels = _np_src_image.shape[2]
+
+ np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
+ np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
+ img_mask = np_mask_grey > 1e-6
+ ref_mask = np_mask_grey < 1e-3
+
+ windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
+ windowed_image /= np.max(windowed_image)
+ windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
+
+ src_fft = _fft2(windowed_image) # get feature statistics from masked src img
+ src_dist = np.absolute(src_fft)
+ src_phase = src_fft / src_dist
+
+ # create a generator with a static seed to make outpainting deterministic / only follow global seed
+ rng = np.random.default_rng(0)
+
+ noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
+ noise_rgb = rng.random((width, height, num_channels))
+ noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
+ noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
+ for c in range(num_channels):
+ noise_rgb[:, :, c] += (1. - color_variation) * noise_grey
+
+ noise_fft = _fft2(noise_rgb)
+ for c in range(num_channels):
+ noise_fft[:, :, c] *= noise_window
+ noise_rgb = np.real(_ifft2(noise_fft))
+ shaped_noise_fft = _fft2(noise_rgb)
+ shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
+
+ brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now
+ contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
+
+ # scikit-image is used for histogram matching, very convenient!
+ shaped_noise = np.real(_ifft2(shaped_noise_fft))
+ shaped_noise -= np.min(shaped_noise)
+ shaped_noise /= np.max(shaped_noise)
+ shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
+ shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
+
+ matched_noise = shaped_noise[:]
+
+ return np.clip(matched_noise, 0., 1.)
+
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Outpainting mk2"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ if not is_img2img:
+ return None
+
+ info = gr.HTML("Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8
")
+
+ pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
+ direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
+ noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
+ color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))
+
+ return [info, pixels, mask_blur, direction, noise_q, color_variation]
+
+ def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
+ initial_seed_and_info = [None, None]
+
+ process_width = p.width
+ process_height = p.height
+
+ p.mask_blur = mask_blur*4
+ p.inpaint_full_res = False
+ p.inpainting_fill = 1
+ p.do_not_save_samples = True
+ p.do_not_save_grid = True
+
+ left = pixels if "left" in direction else 0
+ right = pixels if "right" in direction else 0
+ up = pixels if "up" in direction else 0
+ down = pixels if "down" in direction else 0
+
+ init_img = p.init_images[0]
+ target_w = math.ceil((init_img.width + left + right) / 64) * 64
+ target_h = math.ceil((init_img.height + up + down) / 64) * 64
+
+ if left > 0:
+ left = left * (target_w - init_img.width) // (left + right)
+
+ if right > 0:
+ right = target_w - init_img.width - left
+
+ if up > 0:
+ up = up * (target_h - init_img.height) // (up + down)
+
+ if down > 0:
+ down = target_h - init_img.height - up
+
+ def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
+ is_horiz = is_left or is_right
+ is_vert = is_top or is_bottom
+ pixels_horiz = expand_pixels if is_horiz else 0
+ pixels_vert = expand_pixels if is_vert else 0
+
+ images_to_process = []
+ output_images = []
+ for n in range(count):
+ res_w = init[n].width + pixels_horiz
+ res_h = init[n].height + pixels_vert
+ process_res_w = math.ceil(res_w / 64) * 64
+ process_res_h = math.ceil(res_h / 64) * 64
+
+ img = Image.new("RGB", (process_res_w, process_res_h))
+ img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
+ mask = Image.new("RGB", (process_res_w, process_res_h), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((
+ expand_pixels + mask_blur if is_left else 0,
+ expand_pixels + mask_blur if is_top else 0,
+ mask.width - expand_pixels - mask_blur if is_right else res_w,
+ mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+ ), fill="black")
+
+ np_image = (np.asarray(img) / 255.0).astype(np.float64)
+ np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
+ noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
+ output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
+
+ target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
+ target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
+ p.width = target_width if is_horiz else img.width
+ p.height = target_height if is_vert else img.height
+
+ crop_region = (
+ 0 if is_left else output_images[n].width - target_width,
+ 0 if is_top else output_images[n].height - target_height,
+ target_width if is_left else output_images[n].width,
+ target_height if is_top else output_images[n].height,
+ )
+ mask = mask.crop(crop_region)
+ p.image_mask = mask
+
+ image_to_process = output_images[n].crop(crop_region)
+ images_to_process.append(image_to_process)
+
+ p.init_images = images_to_process
+
+ latent_mask = Image.new("RGB", (p.width, p.height), "white")
+ draw = ImageDraw.Draw(latent_mask)
+ draw.rectangle((
+ expand_pixels + mask_blur * 2 if is_left else 0,
+ expand_pixels + mask_blur * 2 if is_top else 0,
+ mask.width - expand_pixels - mask_blur * 2 if is_right else res_w,
+ mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h,
+ ), fill="black")
+ p.latent_mask = latent_mask
+
+ proc = process_images(p)
+
+ if initial_seed_and_info[0] is None:
+ initial_seed_and_info[0] = proc.seed
+ initial_seed_and_info[1] = proc.info
+
+ for n in range(count):
+ output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
+ output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
+
+ return output_images
+
+ batch_count = p.n_iter
+ batch_size = p.batch_size
+ p.n_iter = 1
+ state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
+ all_processed_images = []
+
+ for i in range(batch_count):
+ imgs = [init_img] * batch_size
+ state.job = f"Batch {i + 1} out of {batch_count}"
+
+ if left > 0:
+ imgs = expand(imgs, batch_size, left, is_left=True)
+ if right > 0:
+ imgs = expand(imgs, batch_size, right, is_right=True)
+ if up > 0:
+ imgs = expand(imgs, batch_size, up, is_top=True)
+ if down > 0:
+ imgs = expand(imgs, batch_size, down, is_bottom=True)
+
+ all_processed_images += imgs
+
+ all_images = all_processed_images
+
+ combined_grid_image = images.image_grid(all_processed_images)
+ unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
+ if opts.return_grid and not unwanted_grid_because_of_img_count:
+ all_images = [combined_grid_image] + all_processed_images
+
+ res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
+
+ if opts.samples_save:
+ for img in all_processed_images:
+ images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+
+ if opts.grid_save and not unwanted_grid_because_of_img_count:
+ images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
+
+ return res
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..d39f61c1073376eae210d955ac1e9eba836402da
--- /dev/null
+++ b/scripts/poor_mans_outpainting.py
@@ -0,0 +1,146 @@
+import math
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image, ImageDraw
+
+from modules import images, processing, devices
+from modules.processing import Processed, process_images
+from modules.shared import opts, cmd_opts, state
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Poor man's outpainting"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ if not is_img2img:
+ return None
+
+ pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
+ direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
+
+ return [pixels, mask_blur, inpainting_fill, direction]
+
+ def run(self, p, pixels, mask_blur, inpainting_fill, direction):
+ initial_seed = None
+ initial_info = None
+
+ p.mask_blur = mask_blur * 2
+ p.inpainting_fill = inpainting_fill
+ p.inpaint_full_res = False
+
+ left = pixels if "left" in direction else 0
+ right = pixels if "right" in direction else 0
+ up = pixels if "up" in direction else 0
+ down = pixels if "down" in direction else 0
+
+ init_img = p.init_images[0]
+ target_w = math.ceil((init_img.width + left + right) / 64) * 64
+ target_h = math.ceil((init_img.height + up + down) / 64) * 64
+
+ if left > 0:
+ left = left * (target_w - init_img.width) // (left + right)
+ if right > 0:
+ right = target_w - init_img.width - left
+
+ if up > 0:
+ up = up * (target_h - init_img.height) // (up + down)
+
+ if down > 0:
+ down = target_h - init_img.height - up
+
+ img = Image.new("RGB", (target_w, target_h))
+ img.paste(init_img, (left, up))
+
+ mask = Image.new("L", (img.width, img.height), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((
+ left + (mask_blur * 2 if left > 0 else 0),
+ up + (mask_blur * 2 if up > 0 else 0),
+ mask.width - right - (mask_blur * 2 if right > 0 else 0),
+ mask.height - down - (mask_blur * 2 if down > 0 else 0)
+ ), fill="black")
+
+ latent_mask = Image.new("L", (img.width, img.height), "white")
+ latent_draw = ImageDraw.Draw(latent_mask)
+ latent_draw.rectangle((
+ left + (mask_blur//2 if left > 0 else 0),
+ up + (mask_blur//2 if up > 0 else 0),
+ mask.width - right - (mask_blur//2 if right > 0 else 0),
+ mask.height - down - (mask_blur//2 if down > 0 else 0)
+ ), fill="black")
+
+ devices.torch_gc()
+
+ grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
+ grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
+ grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
+
+ p.n_iter = 1
+ p.batch_size = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ work = []
+ work_mask = []
+ work_latent_mask = []
+ work_results = []
+
+ for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
+ for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
+ x, w = tiledata[0:2]
+
+ if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
+ continue
+
+ work.append(tiledata[2])
+ work_mask.append(tiledata_mask[2])
+ work_latent_mask.append(tiledata_latent_mask[2])
+
+ batch_count = len(work)
+ print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.")
+
+ state.job_count = batch_count
+
+ for i in range(batch_count):
+ p.init_images = [work[i]]
+ p.image_mask = work_mask[i]
+ p.latent_mask = work_latent_mask[i]
+
+ state.job = f"Batch {i + 1} out of {batch_count}"
+ processed = process_images(p)
+
+ if initial_seed is None:
+ initial_seed = processed.seed
+ initial_info = processed.info
+
+ p.seed = processed.seed + 1
+ work_results += processed.images
+
+
+ image_index = 0
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ x, w = tiledata[0:2]
+
+ if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
+ continue
+
+ tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
+ image_index += 1
+
+ combined_image = images.combine_grid(grid)
+
+ if opts.samples_save:
+ images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p)
+
+ processed = Processed(p, [combined_image], initial_seed, initial_info)
+
+ return processed
+
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e337ec41ffffe11fd88fced0ff4f6338d959571
--- /dev/null
+++ b/scripts/postprocessing_codeformer.py
@@ -0,0 +1,36 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, codeformer_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
+ name = "CodeFormer"
+ order = 3000
+
+ def ui(self):
+ with FormRow():
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+
+ return {
+ "codeformer_visibility": codeformer_visibility,
+ "codeformer_weight": codeformer_weight,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
+ if codeformer_visibility == 0:
+ return
+
+ restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
+ res = Image.fromarray(restored_img)
+
+ if codeformer_visibility < 1.0:
+ res = Image.blend(pp.image, res, codeformer_visibility)
+
+ pp.image = res
+ pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
+ pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7c2baaa28333958818d332324b34bcb8bce3ca
--- /dev/null
+++ b/scripts/postprocessing_gfpgan.py
@@ -0,0 +1,33 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, gfpgan_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
+ name = "GFPGAN"
+ order = 2000
+
+ def ui(self):
+ with FormRow():
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+
+ return {
+ "gfpgan_visibility": gfpgan_visibility,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
+ if gfpgan_visibility == 0:
+ return
+
+ restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
+ res = Image.fromarray(restored_img)
+
+ if gfpgan_visibility < 1.0:
+ res = Image.blend(pp.image, res, gfpgan_visibility)
+
+ pp.image = res
+ pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccec72fcbc72eeffbe24a659bf53ecba71162391
--- /dev/null
+++ b/scripts/postprocessing_upscale.py
@@ -0,0 +1,131 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, shared
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+upscale_cache = {}
+
+
+class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
+ name = "Upscale"
+ order = 1000
+
+ def ui(self):
+ selected_tab = gr.State(value=0)
+
+ with gr.Tabs(elem_id="extras_resize_mode"):
+ with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+
+ with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
+ with FormRow():
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
+
+ with FormRow():
+ extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+
+ with FormRow():
+ extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+
+ tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
+ tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
+
+ return {
+ "upscale_mode": selected_tab,
+ "upscale_by": upscaling_resize,
+ "upscale_to_width": upscaling_resize_w,
+ "upscale_to_height": upscaling_resize_h,
+ "upscale_crop": upscaling_crop,
+ "upscaler_1_name": extras_upscaler_1,
+ "upscaler_2_name": extras_upscaler_2,
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
+ }
+
+ def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
+ if upscale_mode == 1:
+ upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
+ info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}"
+ else:
+ info["Postprocess upscale by"] = upscale_by
+
+ cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ cached_image = upscale_cache.pop(cache_key, None)
+
+ if cached_image is not None:
+ image = cached_image
+ else:
+ image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)
+
+ upscale_cache[cache_key] = image
+ if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache:
+ upscale_cache.pop(next(iter(upscale_cache), None), None)
+
+ if upscale_mode == 1 and upscale_crop:
+ cropped = Image.new("RGB", (upscale_to_width, upscale_to_height))
+ cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))
+ image = cropped
+ info["Postprocess crop to"] = f"{image.width}x{image.height}"
+
+ return image
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+ if upscaler_1_name == "None":
+ upscaler_1_name = None
+
+ upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None)
+ assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}'
+
+ if not upscaler1:
+ return
+
+ if upscaler_2_name == "None":
+ upscaler_2_name = None
+
+ upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None)
+ assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
+
+ upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ pp.info[f"Postprocess upscaler"] = upscaler1.name
+
+ if upscaler2 and upscaler_2_visibility > 0:
+ second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
+
+ pp.info[f"Postprocess upscaler 2"] = upscaler2.name
+
+ pp.image = upscaled_image
+
+ def image_changed(self):
+ upscale_cache.clear()
+
+
+class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
+ name = "Simple Upscale"
+ order = 900
+
+ def ui(self):
+ with FormRow():
+ upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+ upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
+
+ return {
+ "upscale_by": upscale_by,
+ "upscaler_name": upscaler_name,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+ if upscaler_name is None or upscaler_name == "None":
+ return
+
+ upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
+ assert upscaler1, f'could not find upscaler named {upscaler_name}'
+
+ pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
+ pp.info[f"Postprocess upscaler"] = upscaler1.name
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c70998866d4b0853a46e4de73d86c3d9ec9b93
--- /dev/null
+++ b/scripts/prompt_matrix.py
@@ -0,0 +1,111 @@
+import math
+from collections import namedtuple
+from copy import copy
+import random
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import images
+from modules.processing import process_images, Processed
+from modules.shared import opts, cmd_opts, state
+import modules.sd_samplers
+
+
+def draw_xy_grid(xs, ys, x_label, y_label, cell):
+ res = []
+
+ ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
+ hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
+
+ first_processed = None
+
+ state.job_count = len(xs) * len(ys)
+
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
+
+ processed = cell(x, y)
+ if first_processed is None:
+ first_processed = processed
+
+ res.append(processed.images[0])
+
+ grid = images.image_grid(res, rows=len(ys))
+ grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
+
+ first_processed.images = [grid]
+
+ return first_processed
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Prompt matrix"
+
+ def ui(self, is_img2img):
+ gr.HTML('
')
+ with gr.Row():
+ with gr.Column():
+ put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
+ different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
+ with gr.Column():
+ prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
+ variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
+ with gr.Column():
+ margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+
+ return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
+
+ def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
+ modules.processing.fix_seed(p)
+ # Raise error if promp type is not positive or negative
+ if prompt_type not in ["positive", "negative"]:
+ raise ValueError(f"Unknown prompt type {prompt_type}")
+ # Raise error if variations delimiter is not comma or space
+ if variations_delimiter not in ["comma", "space"]:
+ raise ValueError(f"Unknown variations delimiter {variations_delimiter}")
+
+ prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
+ original_prompt = prompt[0] if type(prompt) == list else prompt
+ positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
+
+ delimiter = ", " if variations_delimiter == "comma" else " "
+
+ all_prompts = []
+ prompt_matrix_parts = original_prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts) - 1)
+ for combination_num in range(combination_count):
+ selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
+
+ if put_at_start:
+ selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
+ else:
+ selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
+
+ all_prompts.append(delimiter.join(selected_prompts))
+
+ p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
+ p.do_not_save_grid = True
+
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
+
+ if prompt_type == "positive":
+ p.prompt = all_prompts
+ else:
+ p.negative_prompt = all_prompts
+ p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
+ p.prompt_for_display = positive_prompt
+ processed = process_images(p)
+
+ grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
+ grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size)
+ processed.images.insert(0, grid)
+ processed.index_of_first_image = 1
+ processed.infotexts.insert(0, processed.infotexts[0])
+
+ if opts.grid_save:
+ images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)
+
+ return processed
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
new file mode 100644
index 0000000000000000000000000000000000000000..17c9c967ffeeb3f538c6f95d93ae79c32ca17828
--- /dev/null
+++ b/scripts/prompts_from_file.py
@@ -0,0 +1,177 @@
+import copy
+import math
+import os
+import random
+import sys
+import traceback
+import shlex
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import sd_samplers
+from modules.processing import Processed, process_images
+from PIL import Image
+from modules.shared import opts, cmd_opts, state
+
+
+def process_string_tag(tag):
+ return tag
+
+
+def process_int_tag(tag):
+ return int(tag)
+
+
+def process_float_tag(tag):
+ return float(tag)
+
+
+def process_boolean_tag(tag):
+ return True if (tag == "true") else False
+
+
+prompt_tags = {
+ "sd_model": None,
+ "outpath_samples": process_string_tag,
+ "outpath_grids": process_string_tag,
+ "prompt_for_display": process_string_tag,
+ "prompt": process_string_tag,
+ "negative_prompt": process_string_tag,
+ "styles": process_string_tag,
+ "seed": process_int_tag,
+ "subseed_strength": process_float_tag,
+ "subseed": process_int_tag,
+ "seed_resize_from_h": process_int_tag,
+ "seed_resize_from_w": process_int_tag,
+ "sampler_index": process_int_tag,
+ "sampler_name": process_string_tag,
+ "batch_size": process_int_tag,
+ "n_iter": process_int_tag,
+ "steps": process_int_tag,
+ "cfg_scale": process_float_tag,
+ "width": process_int_tag,
+ "height": process_int_tag,
+ "restore_faces": process_boolean_tag,
+ "tiling": process_boolean_tag,
+ "do_not_save_samples": process_boolean_tag,
+ "do_not_save_grid": process_boolean_tag
+}
+
+
+def cmdargs(line):
+ args = shlex.split(line)
+ pos = 0
+ res = {}
+
+ while pos < len(args):
+ arg = args[pos]
+
+ assert arg.startswith("--"), f'must start with "--": {arg}'
+ assert pos+1 < len(args), f'missing argument for command line option {arg}'
+
+ tag = arg[2:]
+
+ if tag == "prompt" or tag == "negative_prompt":
+ pos += 1
+ prompt = args[pos]
+ pos += 1
+ while pos < len(args) and not args[pos].startswith("--"):
+ prompt += " "
+ prompt += args[pos]
+ pos += 1
+ res[tag] = prompt
+ continue
+
+
+ func = prompt_tags.get(tag, None)
+ assert func, f'unknown commandline option: {arg}'
+
+ val = args[pos+1]
+ if tag == "sampler_name":
+ val = sd_samplers.samplers_map.get(val.lower(), None)
+
+ res[tag] = func(val)
+
+ pos += 2
+
+ return res
+
+
+def load_prompt_file(file):
+ if file is None:
+ lines = []
+ else:
+ lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
+
+ return None, "\n".join(lines), gr.update(lines=7)
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Prompts from file or textbox"
+
+ def ui(self, is_img2img):
+ checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
+ checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
+
+ prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
+ file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
+
+ file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
+
+ # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
+ # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
+ # be unclear to the user that shift-enter is needed.
+ prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
+ return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
+
+ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
+ lines = [x.strip() for x in prompt_txt.splitlines()]
+ lines = [x for x in lines if len(x) > 0]
+
+ p.do_not_save_grid = True
+
+ job_count = 0
+ jobs = []
+
+ for line in lines:
+ if "--" in line:
+ try:
+ args = cmdargs(line)
+ except Exception:
+ print(f"Error parsing line {line} as commandline:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ args = {"prompt": line}
+ else:
+ args = {"prompt": line}
+
+ job_count += args.get("n_iter", p.n_iter)
+
+ jobs.append(args)
+
+ print(f"Will process {len(lines)} lines in {job_count} jobs.")
+ if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
+ p.seed = int(random.randrange(4294967294))
+
+ state.job_count = job_count
+
+ images = []
+ all_prompts = []
+ infotexts = []
+ for n, args in enumerate(jobs):
+ state.job = f"{state.job_no + 1} out of {state.job_count}"
+
+ copy_p = copy.copy(p)
+ for k, v in args.items():
+ setattr(copy_p, k, v)
+
+ proc = process_images(copy_p)
+ images += proc.images
+
+ if checkbox_iterate:
+ p.seed = p.seed + (p.batch_size * p.n_iter)
+ all_prompts += proc.all_prompts
+ infotexts += proc.infotexts
+
+ return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd64d7d385b34f960eebd7aec3233a84ea90609a
--- /dev/null
+++ b/scripts/sd_upscale.py
@@ -0,0 +1,101 @@
+import math
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image
+
+from modules import processing, shared, sd_samplers, images, devices
+from modules.processing import Processed
+from modules.shared import opts, cmd_opts, state
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "SD upscale"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ info = gr.HTML("Will upscale the image by the selected scale factor; use width and height sliders to set tile size
")
+ overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
+ scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
+ upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index"))
+
+ return [info, overlap, upscaler_index, scale_factor]
+
+ def run(self, p, _, overlap, upscaler_index, scale_factor):
+ if isinstance(upscaler_index, str):
+ upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
+ processing.fix_seed(p)
+ upscaler = shared.sd_upscalers[upscaler_index]
+
+ p.extra_generation_params["SD upscale overlap"] = overlap
+ p.extra_generation_params["SD upscale upscaler"] = upscaler.name
+
+ initial_info = None
+ seed = p.seed
+
+ init_img = p.init_images[0]
+ init_img = images.flatten(init_img, opts.img2img_background_color)
+
+ if upscaler.name != "None":
+ img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
+ else:
+ img = init_img
+
+ devices.torch_gc()
+
+ grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)
+
+ batch_size = p.batch_size
+ upscale_count = p.n_iter
+ p.n_iter = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ work = []
+
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ work.append(tiledata[2])
+
+ batch_count = math.ceil(len(work) / batch_size)
+ state.job_count = batch_count * upscale_count
+
+ print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")
+
+ result_images = []
+ for n in range(upscale_count):
+ start_seed = seed + n
+ p.seed = start_seed
+
+ work_results = []
+ for i in range(batch_count):
+ p.batch_size = batch_size
+ p.init_images = work[i * batch_size:(i + 1) * batch_size]
+
+ state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
+ processed = processing.process_images(p)
+
+ if initial_info is None:
+ initial_info = processed.info
+
+ p.seed = processed.seed + 1
+ work_results += processed.images
+
+ image_index = 0
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
+ image_index += 1
+
+ combined_image = images.combine_grid(grid)
+ result_images.append(combined_image)
+
+ if opts.samples_save:
+ images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
+
+ processed = Processed(p, result_images, seed, initial_info)
+
+ return processed
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
new file mode 100644
index 0000000000000000000000000000000000000000..e457d53de2ade37a53cc200d6902323d3d6f25ce
--- /dev/null
+++ b/scripts/xyz_grid.py
@@ -0,0 +1,620 @@
+from collections import namedtuple
+from copy import copy
+from itertools import permutations, chain
+import random
+import csv
+from io import StringIO
+from PIL import Image
+import numpy as np
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
+from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
+from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
+import modules.sd_samplers
+import modules.sd_models
+import modules.sd_vae
+import glob
+import os
+import re
+
+from modules.ui_components import ToolButton
+
+fill_values_symbol = "\U0001f4d2" # 📒
+
+AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
+
+
+def apply_field(field):
+ def fun(p, x, xs):
+ setattr(p, field, x)
+
+ return fun
+
+
+def apply_prompt(p, x, xs):
+ if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
+ raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
+
+ p.prompt = p.prompt.replace(xs[0], x)
+ p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+
+
+def apply_order(p, x, xs):
+ token_order = []
+
+ # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
+ for token in x:
+ token_order.append((p.prompt.find(token), token))
+
+ token_order.sort(key=lambda t: t[0])
+
+ prompt_parts = []
+
+ # Split the prompt up, taking out the tokens
+ for _, token in token_order:
+ n = p.prompt.find(token)
+ prompt_parts.append(p.prompt[0:n])
+ p.prompt = p.prompt[n + len(token):]
+
+ # Rebuild the prompt with the tokens in the order we want
+ prompt_tmp = ""
+ for idx, part in enumerate(prompt_parts):
+ prompt_tmp += part
+ prompt_tmp += x[idx]
+ p.prompt = prompt_tmp + p.prompt
+
+
+def apply_sampler(p, x, xs):
+ sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
+ if sampler_name is None:
+ raise RuntimeError(f"Unknown sampler: {x}")
+
+ p.sampler_name = sampler_name
+
+
+def confirm_samplers(p, xs):
+ for x in xs:
+ if x.lower() not in sd_samplers.samplers_map:
+ raise RuntimeError(f"Unknown sampler: {x}")
+
+
+def apply_checkpoint(p, x, xs):
+ info = modules.sd_models.get_closet_checkpoint_match(x)
+ if info is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
+ modules.sd_models.reload_model_weights(shared.sd_model, info)
+
+
+def confirm_checkpoints(p, xs):
+ for x in xs:
+ if modules.sd_models.get_closet_checkpoint_match(x) is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
+
+
+def apply_clip_skip(p, x, xs):
+ opts.data["CLIP_stop_at_last_layers"] = x
+
+
+def apply_upscale_latent_space(p, x, xs):
+ if x.lower().strip() != '0':
+ opts.data["use_scale_latent_for_hires_fix"] = True
+ else:
+ opts.data["use_scale_latent_for_hires_fix"] = False
+
+
+def find_vae(name: str):
+ if name.lower() in ['auto', 'automatic']:
+ return modules.sd_vae.unspecified
+ if name.lower() == 'none':
+ return None
+ else:
+ choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
+ if len(choices) == 0:
+ print(f"No VAE found for {name}; using automatic")
+ return modules.sd_vae.unspecified
+ else:
+ return modules.sd_vae.vae_dict[choices[0]]
+
+
+def apply_vae(p, x, xs):
+ modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
+
+
+def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
+ p.styles.extend(x.split(','))
+
+
+def format_value_add_label(p, opt, x):
+ if type(x) == float:
+ x = round(x, 8)
+
+ return f"{opt.label}: {x}"
+
+
+def format_value(p, opt, x):
+ if type(x) == float:
+ x = round(x, 8)
+ return x
+
+
+def format_value_join_list(p, opt, x):
+ return ", ".join(x)
+
+
+def do_nothing(p, x, xs):
+ pass
+
+
+def format_nothing(p, opt, x):
+ return ""
+
+
+def str_permutations(x):
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
+ return x
+
+
+class AxisOption:
+ def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
+ self.label = label
+ self.type = type
+ self.apply = apply
+ self.format_value = format_value
+ self.confirm = confirm
+ self.cost = cost
+ self.choices = choices
+
+
+class AxisOptionImg2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_img2img = True
+
+class AxisOptionTxt2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_img2img = False
+
+
+axis_options = [
+ AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
+ AxisOption("Seed", int, apply_field("seed")),
+ AxisOption("Var. seed", int, apply_field("subseed")),
+ AxisOption("Var. strength", float, apply_field("subseed_strength")),
+ AxisOption("Steps", int, apply_field("steps")),
+ AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
+ AxisOption("CFG Scale", float, apply_field("cfg_scale")),
+ AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
+ AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
+ AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+ AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
+ AxisOption("Sigma Churn", float, apply_field("s_churn")),
+ AxisOption("Sigma min", float, apply_field("s_tmin")),
+ AxisOption("Sigma max", float, apply_field("s_tmax")),
+ AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Eta", float, apply_field("eta")),
+ AxisOption("Clip skip", int, apply_clip_skip),
+ AxisOption("Denoising", float, apply_field("denoising_strength")),
+ AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
+ AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
+ AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
+ AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
+]
+
+
+def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
+ hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
+ ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
+ title_texts = [[images.GridAnnotation(z)] for z in z_labels]
+
+ # Temporary list of all the images that are generated to be populated into the grid.
+ # Will be filled with empty images for any individual step that fails to process properly
+ image_cache = [None] * (len(xs) * len(ys) * len(zs))
+
+ processed_result = None
+ cell_mode = "P"
+ cell_size = (1, 1)
+
+ state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
+
+ def process_cell(x, y, z, ix, iy, iz):
+ nonlocal image_cache, processed_result, cell_mode, cell_size
+
+ def index(ix, iy, iz):
+ return ix + iy * len(xs) + iz * len(xs) * len(ys)
+
+ state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
+
+ processed: Processed = cell(x, y, z)
+
+ try:
+ # this dereference will throw an exception if the image was not processed
+ # (this happens in cases such as if the user stops the process from the UI)
+ processed_image = processed.images[0]
+
+ if processed_result is None:
+ # Use our first valid processed result as a template container to hold our full results
+ processed_result = copy(processed)
+ cell_mode = processed_image.mode
+ cell_size = processed_image.size
+ processed_result.images = [Image.new(cell_mode, cell_size)]
+ processed_result.all_prompts = [processed.prompt]
+ processed_result.all_seeds = [processed.seed]
+ processed_result.infotexts = [processed.infotexts[0]]
+
+ image_cache[index(ix, iy, iz)] = processed_image
+ if include_lone_images:
+ processed_result.images.append(processed_image)
+ processed_result.all_prompts.append(processed.prompt)
+ processed_result.all_seeds.append(processed.seed)
+ processed_result.infotexts.append(processed.infotexts[0])
+ except:
+ image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
+
+ if first_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ if second_axes_processed == 'y':
+ for iy, y in enumerate(ys):
+ for iz, z in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iz, z in enumerate(zs):
+ for iy, y in enumerate(ys):
+ process_cell(x, y, z, ix, iy, iz)
+ elif first_axes_processed == 'y':
+ for iy, y in enumerate(ys):
+ if second_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ for iz, z in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iz, z in enumerate(zs):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, z, ix, iy, iz)
+ elif first_axes_processed == 'z':
+ for iz, z in enumerate(zs):
+ if second_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ for iy, y in enumerate(ys):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, z, ix, iy, iz)
+
+ if not processed_result:
+ print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
+ return Processed(p, [])
+
+ sub_grids = [None] * len(zs)
+ for i in range(len(zs)):
+ start_index = i * len(xs) * len(ys)
+ end_index = start_index + len(xs) * len(ys)
+ grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
+ if draw_legend:
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
+ sub_grids[i] = grid
+ if include_sub_grids and len(zs) > 1:
+ processed_result.images.insert(i+1, grid)
+
+ sub_grid_size = sub_grids[0].size
+ z_grid = images.image_grid(sub_grids, rows=1)
+ if draw_legend:
+ z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
+ processed_result.images[0] = z_grid
+
+ return processed_result, sub_grids
+
+
+class SharedSettingsStackHelper(object):
+ def __enter__(self):
+ self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
+ self.vae = opts.sd_vae
+
+ def __exit__(self, exc_type, exc_value, tb):
+ opts.data["sd_vae"] = self.vae
+ modules.sd_models.reload_model_weights()
+ modules.sd_vae.reload_vae_weights()
+
+ opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
+
+
+re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
+re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
+
+re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
+re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "X/Y/Z plot"
+
+ def ui(self, is_img2img):
+ self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
+
+ with gr.Row():
+ with gr.Column(scale=19):
+ with gr.Row():
+ x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
+ x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
+ fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
+
+ with gr.Row():
+ y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
+ y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
+ fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
+
+ with gr.Row():
+ z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
+ z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
+ fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
+
+ with gr.Row(variant="compact", elem_id="axis_options"):
+ with gr.Column():
+ draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
+ no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ with gr.Column():
+ include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
+ include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+ with gr.Column():
+ margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+
+ with gr.Row(variant="compact", elem_id="swap_axes"):
+ swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
+ swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
+ swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
+
+ def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
+ return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
+
+ xy_swap_args = [x_type, x_values, y_type, y_values]
+ swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
+ yz_swap_args = [y_type, y_values, z_type, z_values]
+ swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
+ xz_swap_args = [x_type, x_values, z_type, z_values]
+ swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
+
+ def fill(x_type):
+ axis = self.current_axis_options[x_type]
+ return ", ".join(axis.choices()) if axis.choices else gr.update()
+
+ fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
+ fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
+ fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
+
+ def select_axis(x_type):
+ return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
+
+ x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
+ y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
+ z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
+
+ self.infotext_fields = (
+ (x_type, "X Type"),
+ (x_values, "X Values"),
+ (y_type, "Y Type"),
+ (y_values, "Y Values"),
+ (z_type, "Z Type"),
+ (z_values, "Z Values"),
+ )
+
+ return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
+
+ def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
+ if not no_fixed_seeds:
+ modules.processing.fix_seed(p)
+
+ if not opts.return_grid:
+ p.batch_size = 1
+
+ def process_axis(opt, vals):
+ if opt.label == 'Nothing':
+ return [0]
+
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
+
+ if opt.type == int:
+ valslist_ext = []
+
+ for val in valslist:
+ m = re_range.fullmatch(val)
+ mc = re_range_count.fullmatch(val)
+ if m is not None:
+ start = int(m.group(1))
+ end = int(m.group(2))+1
+ step = int(m.group(3)) if m.group(3) is not None else 1
+
+ valslist_ext += list(range(start, end, step))
+ elif mc is not None:
+ start = int(mc.group(1))
+ end = int(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
+ else:
+ valslist_ext.append(val)
+
+ valslist = valslist_ext
+ elif opt.type == float:
+ valslist_ext = []
+
+ for val in valslist:
+ m = re_range_float.fullmatch(val)
+ mc = re_range_count_float.fullmatch(val)
+ if m is not None:
+ start = float(m.group(1))
+ end = float(m.group(2))
+ step = float(m.group(3)) if m.group(3) is not None else 1
+
+ valslist_ext += np.arange(start, end + step, step).tolist()
+ elif mc is not None:
+ start = float(mc.group(1))
+ end = float(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
+ else:
+ valslist_ext.append(val)
+
+ valslist = valslist_ext
+ elif opt.type == str_permutations:
+ valslist = list(permutations(valslist))
+
+ valslist = [opt.type(x) for x in valslist]
+
+ # Confirm options are valid before starting
+ if opt.confirm:
+ opt.confirm(p, valslist)
+
+ return valslist
+
+ x_opt = self.current_axis_options[x_type]
+ xs = process_axis(x_opt, x_values)
+
+ y_opt = self.current_axis_options[y_type]
+ ys = process_axis(y_opt, y_values)
+
+ z_opt = self.current_axis_options[z_type]
+ zs = process_axis(z_opt, z_values)
+
+ def fix_axis_seeds(axis_opt, axis_list):
+ if axis_opt.label in ['Seed', 'Var. seed']:
+ return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
+ else:
+ return axis_list
+
+ if not no_fixed_seeds:
+ xs = fix_axis_seeds(x_opt, xs)
+ ys = fix_axis_seeds(y_opt, ys)
+ zs = fix_axis_seeds(z_opt, zs)
+
+ if x_opt.label == 'Steps':
+ total_steps = sum(xs) * len(ys) * len(zs)
+ elif y_opt.label == 'Steps':
+ total_steps = sum(ys) * len(xs) * len(zs)
+ elif z_opt.label == 'Steps':
+ total_steps = sum(zs) * len(xs) * len(ys)
+ else:
+ total_steps = p.steps * len(xs) * len(ys) * len(zs)
+
+ if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
+ if x_opt.label == "Hires steps":
+ total_steps += sum(xs) * len(ys) * len(zs)
+ elif y_opt.label == "Hires steps":
+ total_steps += sum(ys) * len(xs) * len(zs)
+ elif z_opt.label == "Hires steps":
+ total_steps += sum(zs) * len(xs) * len(ys)
+ elif p.hr_second_pass_steps:
+ total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
+ else:
+ total_steps *= 2
+
+ total_steps *= p.n_iter
+
+ image_cell_count = p.n_iter * p.batch_size
+ cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
+ plural_s = 's' if len(zs) > 1 else ''
+ print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
+ shared.total_tqdm.updateTotal(total_steps)
+
+ grid_infotext = [None]
+
+ state.xyz_plot_x = AxisInfo(x_opt, xs)
+ state.xyz_plot_y = AxisInfo(y_opt, ys)
+ state.xyz_plot_z = AxisInfo(z_opt, zs)
+
+ # If one of the axes is very slow to change between (like SD model
+ # checkpoint), then make sure it is in the outer iteration of the nested
+ # `for` loop.
+ first_axes_processed = 'x'
+ second_axes_processed = 'y'
+ if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
+ first_axes_processed = 'x'
+ if y_opt.cost > z_opt.cost:
+ second_axes_processed = 'y'
+ else:
+ second_axes_processed = 'z'
+ elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
+ first_axes_processed = 'y'
+ if x_opt.cost > z_opt.cost:
+ second_axes_processed = 'x'
+ else:
+ second_axes_processed = 'z'
+ elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
+ first_axes_processed = 'z'
+ if x_opt.cost > y_opt.cost:
+ second_axes_processed = 'x'
+ else:
+ second_axes_processed = 'y'
+
+ def cell(x, y, z):
+ if shared.state.interrupted:
+ return Processed(p, [], p.seed, "")
+
+ pc = copy(p)
+ pc.styles = pc.styles[:]
+ x_opt.apply(pc, x, xs)
+ y_opt.apply(pc, y, ys)
+ z_opt.apply(pc, z, zs)
+
+ res = process_images(pc)
+
+ if grid_infotext[0] is None:
+ pc.extra_generation_params = copy(pc.extra_generation_params)
+ pc.extra_generation_params['Script'] = self.title()
+
+ if x_opt.label != 'Nothing':
+ pc.extra_generation_params["X Type"] = x_opt.label
+ pc.extra_generation_params["X Values"] = x_values
+ if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
+
+ if y_opt.label != 'Nothing':
+ pc.extra_generation_params["Y Type"] = y_opt.label
+ pc.extra_generation_params["Y Values"] = y_values
+ if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
+
+ if z_opt.label != 'Nothing':
+ pc.extra_generation_params["Z Type"] = z_opt.label
+ pc.extra_generation_params["Z Values"] = z_values
+ if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])
+
+ grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
+
+ return res
+
+ with SharedSettingsStackHelper():
+ processed, sub_grids = draw_xyz_grid(
+ p,
+ xs=xs,
+ ys=ys,
+ zs=zs,
+ x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
+ y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
+ z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
+ cell=cell,
+ draw_legend=draw_legend,
+ include_lone_images=include_lone_images,
+ include_sub_grids=include_sub_grids,
+ first_axes_processed=first_axes_processed,
+ second_axes_processed=second_axes_processed,
+ margin_size=margin_size
+ )
+
+ if opts.grid_save and len(sub_grids) > 1:
+ for sub_grid in sub_grids:
+ images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+
+ if opts.grid_save:
+ images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+
+ return processed
diff --git a/webui.sh b/webui.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8cdad22d310fed20f229b09d7a3160aeb1731a85
--- /dev/null
+++ b/webui.sh
@@ -0,0 +1,186 @@
+#!/usr/bin/env bash
+#################################################
+# Please do not make any changes to this file, #
+# change the variables in webui-user.sh instead #
+#################################################
+
+# If run from macOS, load defaults from webui-macos-env.sh
+if [[ "$OSTYPE" == "darwin"* ]]; then
+ if [[ -f webui-macos-env.sh ]]
+ then
+ source ./webui-macos-env.sh
+ fi
+fi
+
+# Read variables from webui-user.sh
+# shellcheck source=/dev/null
+if [[ -f webui-user.sh ]]
+then
+ source ./webui-user.sh
+fi
+
+# Set defaults
+# Install directory without trailing slash
+if [[ -z "${install_dir}" ]]
+then
+ install_dir="/home/$(whoami)"
+fi
+
+# Name of the subdirectory (defaults to stable-diffusion-webui)
+if [[ -z "${clone_dir}" ]]
+then
+ clone_dir="stable-diffusion-webui"
+fi
+
+# python3 executable
+if [[ -z "${python_cmd}" ]]
+then
+ python_cmd="python3"
+fi
+
+# git executable
+if [[ -z "${GIT}" ]]
+then
+ export GIT="git"
+fi
+
+# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
+if [[ -z "${venv_dir}" ]]
+then
+ venv_dir="venv"
+fi
+
+if [[ -z "${LAUNCH_SCRIPT}" ]]
+then
+ LAUNCH_SCRIPT="launch.py"
+fi
+
+# this script cannot be run as root by default
+can_run_as_root=0
+
+# read any command line flags to the webui.sh script
+while getopts "f" flag > /dev/null 2>&1
+do
+ case ${flag} in
+ f) can_run_as_root=1;;
+ *) break;;
+ esac
+done
+
+# Disable sentry logging
+export ERROR_REPORTING=FALSE
+
+# Do not reinstall existing pip packages on Debian/Ubuntu
+export PIP_IGNORE_INSTALLED=0
+
+# Pretty print
+delimiter="################################################################"
+
+printf "\n%s\n" "${delimiter}"
+printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
+printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
+printf "\n%s\n" "${delimiter}"
+
+# Do not run as root
+if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)"
+ printf "\n%s\n" "${delimiter}"
+fi
+
+if [[ -d .git ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "Repo already cloned, using it as install directory"
+ printf "\n%s\n" "${delimiter}"
+ install_dir="${PWD}/../"
+ clone_dir="${PWD##*/}"
+fi
+
+# Check prerequisites
+gpu_info=$(lspci 2>/dev/null | grep VGA)
+case "$gpu_info" in
+ *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
+ ;;
+ *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
+ printf "\n%s\n" "${delimiter}"
+ printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
+ printf "\n%s\n" "${delimiter}"
+ ;;
+ *)
+ ;;
+esac
+if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+then
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+fi
+
+for preq in "${GIT}" "${python_cmd}"
+do
+ if ! hash "${preq}" &>/dev/null
+ then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+ fi
+done
+
+if ! "${python_cmd}" -c "import venv" &>/dev/null
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
+cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; }
+if [[ -d "${clone_dir}" ]]
+then
+ cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Clone stable-diffusion-webui"
+ printf "\n%s\n" "${delimiter}"
+ "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}"
+ cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+fi
+
+printf "\n%s\n" "${delimiter}"
+printf "Create and activate python venv"
+printf "\n%s\n" "${delimiter}"
+cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+if [[ ! -d "${venv_dir}" ]]
+then
+ "${python_cmd}" -m venv "${venv_dir}"
+ first_launch=1
+fi
+# shellcheck source=/dev/null
+if [[ -f "${venv_dir}"/bin/activate ]]
+then
+ source "${venv_dir}"/bin/activate
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
+if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "Accelerating launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Launching launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+fi
diff --git a/webui_v2.bat b/webui_v2.bat
new file mode 100644
index 0000000000000000000000000000000000000000..2b7fad92f440a60160c009803795ce2d3cd2ef15
--- /dev/null
+++ b/webui_v2.bat
@@ -0,0 +1,85 @@
+@echo off
+
+if not defined PYTHON (set PYTHON=python)
+if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
+
+
+set ERROR_REPORTING=FALSE
+
+mkdir tmp 2>NUL
+
+%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :check_pip
+echo Couldn't launch python
+goto :show_stdout_stderr
+
+:check_pip
+%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
+%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+echo Couldn't install pip
+goto :show_stdout_stderr
+
+:start_venv
+if ["%VENV_DIR%"] == ["-"] goto :skip_venv
+if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
+
+dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+
+for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
+echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
+%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+echo Unable to create venv in directory "%VENV_DIR%"
+goto :show_stdout_stderr
+
+:activate_venv
+set PYTHON="%VENV_DIR%\Scripts\Python.exe"
+echo venv %PYTHON%
+
+:skip_venv
+if [%ACCELERATE%] == ["True"] goto :accelerate
+goto :launch
+
+:accelerate
+echo Checking for accelerate
+set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
+if EXIST %ACCELERATE% goto :accelerate_launch
+
+:launch
+%PYTHON% webui_v2.py %*
+pause
+exit /b
+
+:accelerate_launch
+echo Accelerating
+%ACCELERATE% launch --num_cpu_threads_per_process=6 webui_v2.py
+pause
+exit /b
+
+:show_stdout_stderr
+
+echo.
+echo exit code: %errorlevel%
+
+for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
+if %size% equ 0 goto :show_stderr
+echo.
+echo stdout:
+type tmp\stdout.txt
+
+:show_stderr
+for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
+if %size% equ 0 goto :show_stderr
+echo.
+echo stderr:
+type tmp\stderr.txt
+
+:endofscript
+
+echo.
+echo Launch unsuccessful. Exiting.
+pause