"""Hijacks that let ControlNet units run in batch input mode.

`BatchHijack.do_hijack` replaces `img2img.process_batch` and
`processing.process_images_inner` with methods of `BatchHijack` (keeping the
originals under `__controlnet_original_*` names); `undo_hijack` restores them.
"""
import os
from copy import copy
from enum import Enum
from typing import Tuple, List

from modules import img2img, processing, shared, script_callbacks
from scripts import external_code


class BatchHijack:
    """Coordinate webui generation with ControlNet batch input mode.

    A "batch run" starts when the process-batch callbacks fire and ends when
    the postprocess-batch callbacks fire; in between, one generation is run
    per ControlNet batch entry.
    """

    def __init__(self):
        # True while a batch run is active (set in on_process_batch,
        # cleared in on_postprocess_batch).
        self.is_batch = False
        # Number of batch iterations completed in the current run.
        self.batch_index = 0
        # Number of ControlNet batches in the current run (1 when idle).
        self.batch_size = 1
        # Seeds saved at the start of a run when the
        # 'controlnet_increment_seed_during_batch' option is enabled, so they
        # can be restored at the end; None when nothing was saved.
        self.init_seed = None
        self.init_subseed = None

        # Callback lists dispatched at each stage of a batch run. The first
        # entries are this class's own handlers. NOTE(review): presumably
        # other modules append to these lists - verify against callers.
        self.process_batch_callbacks = [self.on_process_batch]
        self.process_batch_each_callbacks = []
        self.postprocess_batch_each_callbacks = [self.on_postprocess_batch_each]
        self.postprocess_batch_callbacks = [self.on_postprocess_batch]

    def img2img_process_batch_hijack(self, p, *args, **kwargs):
        """Replacement for `img2img.process_batch`.

        When no enabled ControlNet unit uses batch input mode, delegate to the
        original unchanged. Otherwise fire the process-batch callbacks, run the
        original, and always fire the postprocess-batch callbacks afterwards
        (even if the original raises).
        """
        cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
        if not cn_is_batch:
            return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)

        self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)
        try:
            return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

    def processing_process_images_hijack(self, p, *args, **kwargs):
        """Replacement for `processing.process_images_inner`.

        Three cases:
        - A batch run is already active (``self.is_batch``, e.g. started by
          `img2img_process_batch_hijack`): run a single batch iteration.
        - No enabled ControlNet unit uses batch input mode: delegate to the
          original function.
        - Otherwise run one generation per ControlNet batch, saving each
          result to the units' output directory when one is set. The returned
          `Processed` carries all collected images only when the
          'controlnet_show_batch_images_in_ui' option is enabled; otherwise an
          empty `Processed(p, [], p.seed)` is returned.
        """
        if self.is_batch:
            # we are in img2img batch tab, do a single batch iteration
            return self.process_images_cn_batch(p, *args, **kwargs)

        cn_is_batch, batches, output_dir, input_file_names = get_cn_batches(p)
        if not cn_is_batch:
            # we are not in batch mode, fallback to original function
            return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)

        output_images = []
        try:
            self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)
            for batch_i in range(self.batch_size):
                processed = self.process_images_cn_batch(p, *args, **kwargs)
                # Only the images from index_of_first_image onward are kept.
                new_images = processed.images[processed.index_of_first_image:]
                if shared.opts.data.get('controlnet_show_batch_images_in_ui', False):
                    output_images.extend(new_images)

                if output_dir:
                    self.save_images(output_dir, input_file_names[batch_i], new_images)

                if shared.state.interrupted:
                    break
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

        # output_images is non-empty only if the loop ran, so `processed`
        # is bound in that branch.
        if output_images:
            processed.images = output_images
        else:
            processed = processing.Processed(p, [], p.seed)
        return processed

    def process_images_cn_batch(self, p, *args, **kwargs):
        """Run one batch iteration through the original `process_images_inner`.

        Dispatches the per-iteration callbacks before and after the call (the
        "after" callbacks are skipped if the call raises) and forces the
        'control_net_no_detectmap' option to True for the duration of the
        call, restoring its previous value afterwards.
        """
        self.dispatch_callbacks(self.process_batch_each_callbacks, p)
        old_detectmap_output = shared.opts.data.get('control_net_no_detectmap', False)
        try:
            shared.opts.data.update({'control_net_no_detectmap': True})
            processed = getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
        finally:
            shared.opts.data.update({'control_net_no_detectmap': old_detectmap_output})

        self.dispatch_callbacks(self.postprocess_batch_each_callbacks, p, processed)

        # do not go past control net batch size. NOTE(review): the interrupt
        # flag is presumably what stops an outer loop (e.g. img2img batch);
        # it is not cleared here, so it stays set after the final iteration.
        if self.batch_index >= self.batch_size:
            shared.state.interrupted = True

        return processed

    def save_images(self, output_dir, init_image_path, output_images):
        """Save `output_images` into `output_dir`, named after `init_image_path`.

        The first image reuses the input file's basename; image ``n > 0`` gets
        ``-n`` inserted before the extension. RGBA images are converted to RGB
        before saving. `output_dir` is created if it does not exist.

        NOTE: if `output_dir` is the input image's own directory, the first
        output overwrites the input file.
        """
        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.basename(init_image_path)
        stem, ext = os.path.splitext(base_name)
        for n, processed_image in enumerate(output_images):
            filename = base_name if n == 0 else f"{stem}-{n}{ext}"
            if processed_image.mode == 'RGBA':
                processed_image = processed_image.convert("RGB")
            processed_image.save(os.path.join(output_dir, filename))

    def do_hijack(self):
        """Install both hijacks and register `undo_hijack` for script unload."""
        script_callbacks.on_script_unloaded(self.undo_hijack)
        hijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
            new_value=self.img2img_process_batch_hijack,
        )
        hijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
            new_value=self.processing_process_images_hijack
        )

    def undo_hijack(self):
        """Restore the original `process_batch` and `process_images_inner`."""
        unhijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
        )
        unhijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
        )

    def adjust_job_count(self, p):
        """Scale `shared.state.job_count` by the number of ControlNet batches.

        A job count of -1 (presumably "not yet set") is first replaced by
        `p.n_iter`.
        """
        if shared.state.job_count == -1:
            shared.state.job_count = p.n_iter
        shared.state.job_count *= self.batch_size

    def on_process_batch(self, p, batches, output_dir, *args):
        """Start a batch run.

        Marks the run active, fixes p's seed via `processing.fix_seed`, saves
        the starting seeds when seed incrementing is enabled, scales the job
        count, disables grid saving, and disables per-sample saving when an
        `output_dir` is given (presumably because results go to `output_dir`
        via `save_images` instead).
        """
        print('controlnet batch mode')
        self.is_batch = True
        self.batch_index = 0
        self.batch_size = len(batches)
        processing.fix_seed(p)
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
            self.init_seed = p.seed
            self.init_subseed = p.subseed
        else:
            # Never carry saved seeds over from a previous run.
            self.init_seed = None
            self.init_subseed = None
        self.adjust_job_count(p)
        p.do_not_save_grid = True
        p.do_not_save_samples = bool(output_dir)

    def on_postprocess_batch_each(self, p, *args):
        """Finish one batch iteration: count it and optionally advance seeds.

        With seed incrementing enabled, `p.seed` and `p.subseed` advance by
        ``len(p.all_prompts)`` so the next iteration starts after the seeds
        just used.
        """
        self.batch_index += 1
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
            p.seed = p.seed + len(p.all_prompts)
            p.subseed = p.subseed + len(p.all_prompts)

    def on_postprocess_batch(self, p, *args):
        """End a batch run: reset state and restore the starting seeds.

        Seeds are restored only when seed incrementing is enabled AND seeds
        were actually saved by `on_process_batch` for this run; the saved
        seeds are then cleared so they cannot leak into a later run.
        """
        self.is_batch = False
        self.batch_index = 0
        self.batch_size = 1
        if (shared.opts.data.get('controlnet_increment_seed_during_batch', False)
                and self.init_seed is not None):
            p.seed = self.init_seed
            p.all_seeds = [self.init_seed]
            p.subseed = self.init_subseed
            p.all_subseeds = [self.init_subseed]
        self.init_seed = None
        self.init_subseed = None

    def dispatch_callbacks(self, callbacks, *args):
        """Call every callback in `callbacks` in order with `*args`."""
        for callback in callbacks:
            callback(*args)


def hijack_function(module, name, new_name, new_value):
    """Replace `module.<name>` with `new_value`, keeping the original as `module.<new_name>`.

    Any earlier hijack under `new_name` is undone first, so a reload does not
    wrap an already-hijacked function.
    """
    # restore original function in case of reload
    unhijack_function(module=module, name=name, new_name=new_name)
    setattr(module, new_name, getattr(module, name))
    setattr(module, name, new_value)


def unhijack_function(module, name, new_name):
    """Restore `module.<name>` from the `module.<new_name>` backup, if one exists."""
    if hasattr(module, new_name):
        setattr(module, name, getattr(module, new_name))
        delattr(module, new_name)


class InputMode(Enum):
    SIMPLE = "simple"
    BATCH = "batch"


def _is_batch_unit(unit) -> bool:
    """Return True if `unit` uses batch input mode (a missing `input_mode` means simple)."""
    return getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH


def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]:
    """Collect the per-iteration ControlNet inputs for `p`.

    Returns ``(any_unit_is_batch, batches, output_dir, input_file_names)``:
    - `batches[i]` holds, for every enabled unit in order, either its
      `batch_images[i]` (batch-mode units) or its `image` (all other units).
    - The number of batches is the shortest batch-mode unit's image count,
      or 1 when no unit is in batch mode.
    - `output_dir` and `input_file_names` come from the last batch-mode unit
      ('' and [] when there is none).
    """
    units = external_code.get_all_units_in_processing(p)
    # Shallow copies, so replacing batch_images below does not touch p's units.
    units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)]
    # One classification shared by sizing and indexing, so a unit is never
    # treated as simple in one place and as batch in the other.
    batch_units = [unit for unit in units if _is_batch_unit(unit)]

    output_dir = ''
    input_file_names = []
    for unit in batch_units:
        output_dir = getattr(unit, 'output_dir', '')
        if isinstance(unit.batch_images, str):
            # Presumably a directory path; shared.listfiles expands it to files.
            unit.batch_images = shared.listfiles(unit.batch_images)
        input_file_names = unit.batch_images

    any_unit_is_batch = bool(batch_units)
    if any_unit_is_batch:
        cn_batch_size = min(len(unit.batch_images) for unit in batch_units)
    else:
        cn_batch_size = 1

    batches = [
        [unit.batch_images[i] if _is_batch_unit(unit) else unit.image for unit in units]
        for i in range(cn_batch_size)
    ]

    return any_unit_is_batch, batches, output_dir, input_file_names


instance = BatchHijack()