Please test on small images before the actual upscale. Default params require denoise <= 0.6
') + with gr.Row(variant='compact'): + noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) + noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) + noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) + + # The control includes txt2img and img2img, we use t2i and i2i to distinguish them + with gr.Group(elem_id=f'MD-bbox-control-{tab}') as tab_bbox: + with gr.Accordion('Region Prompt Control', open=False): + with gr.Row(variant='compact'): + enable_bbox_control = gr.Checkbox(label='Enable Control', value=False, elem_id=uid('enable-bbox-control')) + draw_background = gr.Checkbox(label='Draw full canvas background', value=False, elem_id=uid('draw-background')) + causal_layers = gr.Checkbox(label='Causalize layers', value=False, visible=False, elem_id='MD-causal-layers') # NOTE: currently not used + + with gr.Row(variant='compact'): + create_button = gr.Button(value="Create txt2img canvas" if not is_img2img else "From img2img", elem_id='MD-create-canvas') + + bbox_controls: List[Component] = [] # control set for each bbox + with gr.Row(variant='compact'): + ref_image = gr.Image(label='Ref image (for conviently locate regions)', image_mode=None, elem_id=f'MD-bbox-ref-{tab}', interactive=True) + if not is_img2img: + # gradio has a serious bug: it cannot accept multiple inputs when you use both js and fn. + # to workaround this, we concat the inputs into a single string and parse it in js + def create_t2i_ref(string): + w, h = [int(x) for x in string.split('x')] + w = max(w, opt_f) + h = max(h, opt_f) + return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 + create_button.click( + fn=create_t2i_ref, + inputs=overwrite_size, + outputs=ref_image, + _js='onCreateT2IRefClick', + show_progress=False) + else: + create_button.click(fn=None, outputs=ref_image, _js='onCreateI2IRefClick', show_progress=False) + + with gr.Row(variant='compact'): + cfg_name = gr.Textbox(label='Custom Config File', value='config.json', elem_id=uid('cfg-name')) + cfg_dump = gr.Button(value='💾 Save', variant='tool') + cfg_load = gr.Button(value='⚙️ Load', variant='tool') + + with gr.Row(variant='compact'): + cfg_tip = gr.HTML(value='', visible=False) + + for i in range(BBOX_MAX_NUM): + # Only when displaying & png generate info we use index i+1, in other cases we use i + with gr.Accordion(f'Region {i+1}', open=False, elem_id=f'MD-accordion-{tab}-{i}'): + with gr.Row(variant='compact'): + e = gr.Checkbox(label=f'Enable Region {i+1}', value=False, elem_id=f'MD-bbox-{tab}-{i}-enable') + e.change(fn=None, inputs=e, outputs=e, _js=f'e => onBoxEnableClick({is_t2i}, {i}, e)', show_progress=False) + + blend_mode = gr.Dropdown(label='Type', choices=[e.value for e in BlendMode], value=BlendMode.BACKGROUND.value, elem_id=f'MD-{tab}-{i}-blend-mode') + feather_ratio = gr.Slider(label='Feather', value=0.2, minimum=0, maximum=1, step=0.05, visible=False, elem_id=f'MD-{tab}-{i}-feather') + + blend_mode.change(fn=lambda x: gr_show(x==BlendMode.FOREGROUND.value), inputs=blend_mode, outputs=feather_ratio, show_progress=False) + + with gr.Row(variant='compact'): + x = gr.Slider(label='x', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-x') + y = gr.Slider(label='y', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, 
elem_id=f'MD-{tab}-{i}-y') + + with gr.Row(variant='compact'): + w = gr.Slider(label='w', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-w') + h = gr.Slider(label='h', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-h') + + x.change(fn=None, inputs=x, outputs=x, _js=f'v => onBoxChange({is_t2i}, {i}, "x", v)', show_progress=False) + y.change(fn=None, inputs=y, outputs=y, _js=f'v => onBoxChange({is_t2i}, {i}, "y", v)', show_progress=False) + w.change(fn=None, inputs=w, outputs=w, _js=f'v => onBoxChange({is_t2i}, {i}, "w", v)', show_progress=False) + h.change(fn=None, inputs=h, outputs=h, _js=f'v => onBoxChange({is_t2i}, {i}, "h", v)', show_progress=False) + + prompt = gr.Text(show_label=False, placeholder=f'Prompt, will append to your {tab} prompt', max_lines=2, elem_id=f'MD-{tab}-{i}-prompt') + neg_prompt = gr.Text(show_label=False, placeholder='Negative Prompt, will also be appended', max_lines=1, elem_id=f'MD-{tab}-{i}-neg-prompt') + with gr.Row(variant='compact'): + seed = gr.Number(label='Seed', value=-1, visible=True, elem_id=f'MD-{tab}-{i}-seed') + random_seed = gr.Button(value='🎲', variant='tool', elem_id=f'MD-{tab}-{i}-random_seed') + reuse_seed = gr.Button(value='♻️', variant='tool', elem_id=f'MD-{tab}-{i}-reuse_seed') + random_seed.click(fn=lambda: -1, outputs=seed, show_progress=False) + reuse_seed.click(fn=None, inputs=seed, outputs=seed, _js=f'e => getSeedInfo({is_t2i}, {i+1}, e)', show_progress=False) + + control = [e, x, y, w, h, prompt, neg_prompt, blend_mode, feather_ratio, seed] + assert len(control) == NUM_BBOX_PARAMS + bbox_controls.extend(control) + + # NOTE: dynamically hard coded!! + load_regions_js = ''' + function onBoxChangeAll(ref_image, cfg_name, ...args) { + const is_t2i = %s; + const n_bbox = %d; + const n_ctrl = %d; + for (let i=0; iPlease test on small images before actual upscale. Default params require denoise <= 0.6
') + with gr.Row(variant='compact'): + noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) + noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) + noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) + + # The control includes txt2img and img2img, we use t2i and i2i to distinguish them + + return [ + enabled, method, + keep_input_size, + window_size, overlap, batch_size, + scale_factor, + noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel, + control_tensor_cpu, + random_jitter, + c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode + ] + + + def process(self, p: Processing, + enabled: bool, method: str, + keep_input_size: bool, + window_size:int, overlap: int, tile_batch_size: int, + scale_factor: float, + noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, + control_tensor_cpu: bool, + random_jitter:bool, + c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode + ): + + # unhijack & unhook, in case it broke at last time + self.reset() + p.mixture = mixture_mode + if not mixture_mode: + sigma = sigma/2 + if not enabled: return + + ''' upscale ''' + # store canvas size settings + if hasattr(p, "init_images"): + p.init_images_original_md = [img.copy() for img in p.init_images] + p.width_original_md = p.width + p.height_original_md = p.height + p.current_scale_num = 1 + p.gaussian_filter = gaussian_filter + p.scale_factor = int(scale_factor) + + is_img2img = hasattr(p, "init_images") and len(p.init_images) > 0 + if is_img2img: + init_img = p.init_images[0] + init_img = images.flatten(init_img, opts.img2img_background_color) + image = init_img + if keep_input_size: + p.width = image.width + p.height = image.height + p.width_original_md = p.width + p.height_original_md = p.height + else: #XXX:To adapt to noise inversion, we do not multiply the scale factor here + p.width = p.width_original_md + p.height = p.height_original_md + else: # txt2img + p.width = p.width_original_md + p.height = p.height_original_md + + if 'png info': + info = {} + p.extra_generation_params["Tiled Diffusion"] = info + + info['Method'] = method + info['Window Size'] = window_size + info['Tile Overlap'] = overlap + info['Tile batch size'] = tile_batch_size + info["Global batch size"] = batch_size_g + + if is_img2img: + info['Upscale factor'] = scale_factor + if keep_input_size: + info['Keep input size'] = keep_input_size + if noise_inverse: + info['NoiseInv'] = noise_inverse + info['NoiseInv Steps'] = noise_inverse_steps + info['NoiseInv Retouch'] = noise_inverse_retouch + info['NoiseInv Renoise strength'] = noise_inverse_renoise_strength + info['NoiseInv Kernel size'] = noise_inverse_renoise_kernel + + ''' ControlNet hackin ''' + try: + from scripts.cldm import ControlNet + + for script in p.scripts.scripts + p.scripts.alwayson_scripts: + if hasattr(script, "latest_network") and script.title().lower() == "controlnet": + self.controlnet_script = script + print("[Demo Fusion] ControlNet found, support is enabled.") + break + except ImportError: + pass + + ''' StableSR hackin ''' + for script in p.scripts.scripts: + if hasattr(script, 
"stablesr_model") and script.title().lower() == "stablesr": + if script.stablesr_model is not None: + self.stablesr_script = script + print("[Demo Fusion] StableSR found, support is enabled.") + break + + ''' hijack inner APIs, see unhijack in reset() ''' + Script.create_sampler_original_md = sd_samplers.create_sampler + + sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack( + name, model, p, Method_2(method), control_tensor_cpu,window_size, noise_inverse, noise_inverse_steps, noise_inverse_retouch, + noise_inverse_renoise_strength, noise_inverse_renoise_kernel, overlap, tile_batch_size,random_jitter,batch_size_g + ) + + + p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts: self.sample_hijack( + conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img, + window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g) + + processing.create_infotext_ori = processing.create_infotext + + p.width_list = [p.height] + p.height_list = [p.height] + + processing.create_infotext = create_infotext_hijack + ## end + + + def postprocess_batch(self, p: Processing, enabled, *args, **kwargs): + if not enabled: return + + if self.delegate is not None: self.delegate.reset_controlnet_tensors() + + def postprocess_batch_list(self, p, pp, enabled, *args, **kwargs): + if not enabled: return + for idx,image in enumerate(pp.images): + idx_b = idx//p.batch_size + pp.images[idx] = image[:,:image.shape[1]//(p.scale_factor)*(idx_b+1),:image.shape[2]//(p.scale_factor)*(idx_b+1)] + p.seeds = [item for _ in range(p.scale_factor) for item in p.seeds] + p.prompts = [item for _ in range(p.scale_factor) for item in p.prompts] + p.all_negative_prompts = [item for _ in range(p.scale_factor) for item in p.all_negative_prompts] + p.negative_prompts = [item for _ in range(p.scale_factor) for item in p.negative_prompts] + if p.color_corrections != None: + p.color_corrections = [item for _ in range(p.scale_factor) for item in p.color_corrections] + p.width_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.width for _ in range(p.batch_size)]] + p.height_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.height for _ in range(p.batch_size)]] + return + + def postprocess(self, p: Processing, processed, enabled, *args): + if not enabled: return + # unhijack & unhook + self.reset() + + # restore canvas size settings + if hasattr(p, 'init_images') and hasattr(p, 'init_images_original_md'): + p.init_images.clear() # NOTE: do NOT change the list object, compatible with shallow copy of XYZ-plot + p.init_images.extend(p.init_images_original_md) + del p.init_images_original_md + p.width = p.width_original_md ; del p.width_original_md + p.height = p.height_original_md ; del p.height_original_md + + # clean up noise inverse latent for folder-based processing + if hasattr(p, 'noise_inverse_latent'): + del p.noise_inverse_latent + + ''' ↓↓↓ inner API hijack ↓↓↓ ''' + @torch.no_grad() + def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g): + ################################################## Phase Initialization ###################################################### + + if not image_ori: + p.current_step = 0 + p.denoising_strength = strength + # p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) 
#NOTE:Wrong but very useful. If corrected, please replace with the content with the following lines + # latents = p.rng.next() + + p.sampler = Script.create_sampler_original_md(p.sampler_name, p.sd_model) #scale + x = p.rng.next() + print("### Phase 1 Denoising ###") + latents = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.txt2img_image_conditioning(x)) + latents_ = F.pad(latents, (0, latents.shape[3]*(p.scale_factor-1), 0, latents.shape[2]*(p.scale_factor-1))) + res = latents_ + del x + p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) + starting_scale = 2 + else: # img2img + print("### Encoding Real Image ###") + latents = p.init_latent + starting_scale = 1 + + + anchor_mean = latents.mean() + anchor_std = latents.std() + + devices.torch_gc() + + ####################################################### Phase Upscaling ##################################################### + p.cosine_scale_1 = c1 + p.cosine_scale_2 = c2 + p.cosine_scale_3 = c3 + self.delegate.sig = sigma + p.latents = latents + for current_scale_num in range(starting_scale, p.scale_factor+1): + p.current_scale_num = current_scale_num + print("### Phase {} Denoising ###".format(current_scale_num)) + p.current_height = p.height_original_md * current_scale_num + p.current_width = p.width_original_md * current_scale_num + + + p.latents = F.interpolate(p.latents, size=(int(p.current_height / opt_f), int(p.current_width / opt_f)), mode='bicubic') + p.rng = rng.ImageRNG(p.latents.shape[1:], p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + + + self.delegate.w = int(p.current_width / opt_f) + self.delegate.h = int(p.current_height / opt_f) + self.delegate.get_views(overlap, tile_batch_size,batch_size_g) + + info = ', '.join([ + # f"{method.value} hooked into {name!r} sampler", + f"Tile size: {self.delegate.window_size}", + f"Tile count: {self.delegate.num_tiles}", + f"Batch size: {self.delegate.tile_bs}", + f"Tile batches: {len(self.delegate.batched_bboxes)}", + f"Global batch size: {self.delegate.global_tile_bs}", + f"Global batches: {len(self.delegate.global_batched_bboxes)}", + ]) + + print(info) + + noise = p.rng.next() + if hasattr(p,'initial_noise_multiplier'): + if p.initial_noise_multiplier != 1.0: + p.extra_generation_params["Noise multiplier"] = p.initial_noise_multiplier + noise *= p.initial_noise_multiplier + else: + p.image_conditioning = p.txt2img_image_conditioning(noise) + + p.noise = noise + p.x = p.latents.clone() + p.current_step=0 + + p.latents = p.sampler.sample_img2img(p,p.latents, noise , conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) + if self.flag_noise_inverse: + self.delegate.sampler_raw.sample_img2img = self.delegate.sample_img2img_original + self.flag_noise_inverse = False + + p.latents = (p.latents - p.latents.mean()) / p.latents.std() * anchor_std + anchor_mean + latents_ = F.pad(p.latents, (0, p.latents.shape[3]//current_scale_num*(p.scale_factor-current_scale_num), 0, p.latents.shape[2]//current_scale_num*(p.scale_factor-current_scale_num))) + if current_scale_num==1: + res = latents_ + else: + res = torch.concatenate((res,latents_),axis=0) + + ######################################################################################################################################### + + return res + + @staticmethod + def callback_hijack(self_sampler,d,p): + p.current_step = d['i'] + + if self_sampler.stop_at is 
not None and p.current_step > self_sampler.stop_at: + raise InterruptedException + + state.sampling_step = p.current_step + shared.total_tqdm.update() + p.current_step += 1 + + + def create_sampler_hijack( + self, name: str, model: LatentDiffusion, p: Processing, method: Method_2, control_tensor_cpu:bool,window_size, noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch:float, + noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, overlap:int, tile_batch_size:int, random_jitter:bool,batch_size_g:int + ): + if self.delegate is not None: + # samplers are stateless, we reuse it if possible + if self.delegate.sampler_name == name: + # before we reuse the sampler, we refresh the control tensor + # so that we are compatible with ControlNet batch processing + if self.controlnet_script: + self.delegate.prepare_controlnet_tensors(refresh=True) + return self.delegate.sampler_raw + else: + self.reset() + sd_samplers_common.Sampler.callback_ori = sd_samplers_common.Sampler.callback_state + sd_samplers_common.Sampler.callback_state = lambda self_sampler,d:Script.callback_hijack(self_sampler,d,p) + + self.flag_noise_inverse = hasattr(p, "init_images") and len(p.init_images) > 0 and noise_inverse + flag_noise_inverse = self.flag_noise_inverse + if flag_noise_inverse: + print('warn: noise inversion only supports the "Euler" sampler, switch to it sliently...') + name = 'Euler' + p.sampler_name = 'Euler' + if name is None: print('>> name is empty') + if model is None: print('>> model is empty') + sampler = Script.create_sampler_original_md(name, model) + if method ==Method_2.DEMO_FU: delegate_cls = DemoFusion + else: raise NotImplementedError(f"Method {method} not implemented.") + + delegate = delegate_cls(p, sampler) + delegate.window_size = min(min(window_size,p.width//8),p.height//8) + p.random_jitter = random_jitter + + if flag_noise_inverse: + get_cache_callback = self.noise_inverse_get_cache + set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch) + delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) + + # delegate.get_views(overlap,tile_batch_size,batch_size_g) + if self.controlnet_script: + delegate.init_controlnet(self.controlnet_script, control_tensor_cpu) + if self.stablesr_script: + delegate.init_stablesr(self.stablesr_script) + + # init everything done, perform sanity check & pre-computations + # hijack the behaviours + delegate.hook() + + self.delegate = delegate + + exts = [ + "ContrlNet" if self.controlnet_script else None, + "StableSR" if self.stablesr_script else None, + ] + ext_info = ', '.join([e for e in exts if e]) + if ext_info: ext_info = f' (ext: {ext_info})' + print(ext_info) + + return delegate.sampler_raw + + def create_random_tensors_hijack( + self, bbox_settings: Dict, region_info: Dict, + shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None, + ): + org_random_tensors = Script.create_random_tensors_original_md(shape, seeds, subseeds, subseed_strength, seed_resize_from_h, seed_resize_from_w, p) + height, width = shape[1], shape[2] + background_noise = torch.zeros_like(org_random_tensors) + background_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) + foreground_noise = torch.zeros_like(org_random_tensors) + foreground_noise_count = torch.zeros((1, 1, height, 
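# A sketch of the per-region noise layering this function implements: every enabled
# region draws noise from its own seed into a latent-sized buffer, overlapping
# pixels are averaged via a count map, and untouched pixels keep the original
# noise. Shapes and the example regions are illustrative only.
import torch

def layer_region_noise(base: torch.Tensor, regions):
    # base: [1, C, H, W] noise; regions: list of (seed, x, y, w, h) in latent pixels
    _, c, height, width = base.shape
    layer = torch.zeros_like(base)
    count = torch.zeros((1, 1, height, width))
    for seed, x, y, w, h in regions:
        torch.manual_seed(seed)                       # deterministic noise per region
        layer[:, :, y:y+h, x:x+w] += torch.randn((1, c, h, w))
        count[:, :, y:y+h, x:x+w] += 1
    layer = layer / count.clamp(min=1)                # average where regions overlap
    return torch.where(count > 0, layer, base)        # keep base noise elsewhere

noise = torch.randn((1, 4, 64, 64))
mixed = layer_region_noise(noise, [(42, 0, 0, 32, 32), (7, 16, 16, 32, 32)])
print(mixed.shape)  # torch.Size([1, 4, 64, 64])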
width), device=org_random_tensors.device) + + for i, v in bbox_settings.items(): + seed = get_fixed_seed(v.seed) + x, y, w, h = v.x, v.y, v.w, v.h + # convert to pixel + x = int(x * width) + y = int(y * height) + w = math.ceil(w * width) + h = math.ceil(h * height) + # clamp + x = max(0, x) + y = max(0, y) + w = min(width - x, w) + h = min(height - y, h) + # create random tensor + torch.manual_seed(seed) + rand_tensor = torch.randn((1, org_random_tensors.shape[1], h, w), device=devices.cpu) + if BlendMode(v.blend_mode) == BlendMode.BACKGROUND: + background_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(background_noise.device) + background_noise_count[:, :, y:y+h, x:x+w] += 1 + elif BlendMode(v.blend_mode) == BlendMode.FOREGROUND: + foreground_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(foreground_noise.device) + foreground_noise_count[:, :, y:y+h, x:x+w] += 1 + else: + raise NotImplementedError + region_info['Region ' + str(i+1)]['seed'] = seed + + # average + background_noise = torch.where(background_noise_count > 1, background_noise / background_noise_count, background_noise) + foreground_noise = torch.where(foreground_noise_count > 1, foreground_noise / foreground_noise_count, foreground_noise) + # paste two layers to original random tensor + org_random_tensors = torch.where(background_noise_count > 0, background_noise, org_random_tensors) + org_random_tensors = torch.where(foreground_noise_count > 0, foreground_noise, org_random_tensors) + return org_random_tensors + + ''' ↓↓↓ helper methods ↓↓↓ ''' + + def dump_regions(self, cfg_name, *bbox_controls): + if not cfg_name: return gr_value(f'Config file name cannot be empty.', visible=True) + + bbox_settings = build_bbox_settings(bbox_controls) + data = {'bbox_controls': [v._asdict() for v in bbox_settings.values()]} + + if not os.path.exists(CFG_PATH): os.makedirs(CFG_PATH) + fp = os.path.join(CFG_PATH, cfg_name) + with open(fp, 'w', encoding='utf-8') as fh: + json.dump(data, fh, indent=2, ensure_ascii=False) + + return gr_value(f'Config saved to {fp}.', visible=True) + + def load_regions(self, ref_image, cfg_name, *bbox_controls): + if ref_image is None: + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Please create or upload a ref image first.', visible=True)] + fp = os.path.join(CFG_PATH, cfg_name) + if not os.path.exists(fp): + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Config {fp} not found.', visible=True)] + + try: + with open(fp, 'r', encoding='utf-8') as fh: + data = json.load(fh) + except Exception as e: + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Failed to load config {fp}: {e}', visible=True)] + + num_boxes = len(data['bbox_controls']) + data_list = [] + for i in range(BBOX_MAX_NUM): + if i < num_boxes: + for k in BBoxSettings._fields: + if k in data['bbox_controls'][i]: + data_list.append(data['bbox_controls'][i][k]) + else: + data_list.append(None) + else: + data_list.extend(DEFAULT_BBOX_SETTINGS) + + return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)] + + + def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float): + self.noise_inverse_cache = NoiseInverseCache(p.sd_model.sd_model_hash, x0, xt, steps, retouch, prompts) + + def noise_inverse_get_cache(self): + return self.noise_inverse_cache + + + def reset(self): + ''' unhijack inner APIs, see hijack in process() ''' + if hasattr(Script, "create_sampler_original_md"): + sd_samplers.create_sampler = 
Script.create_sampler_original_md + del Script.create_sampler_original_md + if hasattr(Script, "create_random_tensors_original_md"): + processing.create_random_tensors = Script.create_random_tensors_original_md + del Script.create_random_tensors_original_md + if hasattr(sd_samplers_common.Sampler, "callback_ori"): + sd_samplers_common.Sampler.callback_state = sd_samplers_common.Sampler.callback_ori + del sd_samplers_common.Sampler.callback_ori + if hasattr(processing, "create_infotext_ori"): + processing.create_infotext = processing.create_infotext_ori + del processing.create_infotext_ori + DemoFusion.unhook() + self.delegate = None + + def reset_and_gc(self): + self.reset() + self.noise_inverse_cache = None + + import gc; gc.collect() + devices.torch_gc() + + try: + import os + import psutil + mem = psutil.Process(os.getpid()).memory_info() + print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB') + from modules.shared import mem_mon as vram_mon + from modules.memmon import MemUsageMonitor + vram_mon: MemUsageMonitor + free, total = vram_mon.cuda_mem_get_info() + print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB') + except: + pass diff --git a/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/scripts/tilevae.py b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/scripts/tilevae.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4ed4a2e3b0e0213911cb8d289fb2d36d974d9c --- /dev/null +++ b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/scripts/tilevae.py @@ -0,0 +1,758 @@ +''' +# ------------------------------------------------------------------------ +# +# Tiled VAE +# +# Introducing a revolutionary new optimization designed to make +# the VAE work with giant images on limited VRAM! +# Say goodbye to the frustration of OOM and hello to seamless output! +# +# ------------------------------------------------------------------------ +# +# This script is a wild hack that splits the image into tiles, +# encodes each tile separately, and merges the result back together. +# +# Advantages: +# - The VAE can now work with giant images on limited VRAM +# (~10 GB for 8K images!) +# - The merged output is completely seamless without any post-processing. +# +# Drawbacks: +# - NaNs always appear in for 8k images when you use fp16 (half) VAE +# You must use --no-half-vae to disable half VAE for that giant image. +# - The gradient calculation is not compatible with this hack. It +# will break any backward() or torch.autograd.grad() that passes VAE. +# (But you can still use the VAE to generate training data.) +# +# How it works: +# 1. The image is split into tiles, which are then padded with 11/32 pixels' in the decoder/encoder. +# 2. When Fast Mode is disabled: +# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile. +# 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile. +# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues. +# 4. A zigzag execution order is used to reduce unnecessary data transfer. +# 3. When Fast Mode is enabled: +# 1. The original input is downsampled and passed to a separate task queue. +# 2. Its group norm parameters are recorded and used by all tiles' task queues. +# 3. Each tile is separately processed without any RAM-VRAM data transfer. +# 4. 
After all tiles are processed, tiles are written to a result buffer and returned. +# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode. +# +# Enjoy! +# +# @Author: LI YI @ Nanyang Technological University - Singapore +# @Date: 2023-03-02 +# @License: CC BY-NC-SA 4.0 +# +# Please give me a star if you like this project! +# +# ------------------------------------------------------------------------- +''' + +import gc +import math +from time import time +from tqdm import tqdm + +import torch +import torch.version +import torch.nn.functional as F +import gradio as gr + +import modules.scripts as scripts +import modules.devices as devices +from modules.shared import state, opts +from modules.ui import gr_show +from modules.processing import opt_f +from modules.sd_vae_approx import cheap_approximation +from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock + +from tile_utils.attn import get_attn_func +from tile_utils.typing import Processing + +if hasattr(opts, 'hypertile_enable_unet'): # webui >= 1.7 + from modules.ui_components import InputAccordion +else: + InputAccordion = None + + +def get_rcmd_enc_tsize(): + if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: + total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 + if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 + elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 + elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 + else: ENCODER_TILE_SIZE = 960 + else: ENCODER_TILE_SIZE = 512 + return ENCODER_TILE_SIZE + + +def get_rcmd_dec_tsize(): + if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: + total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 + if total_memory > 30*1000: DECODER_TILE_SIZE = 256 + elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 + elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 + elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 + else: DECODER_TILE_SIZE = 64 + else: DECODER_TILE_SIZE = 64 + return DECODER_TILE_SIZE + + +def inplace_nonlinearity(x): + # Test: fix for Nans + return F.silu(x, inplace=True) + + +def attn2task(task_queue, net): + attn_forward = get_attn_func() + task_queue.append(('store_res', lambda x: x)) + task_queue.append(('pre_norm', net.norm)) + task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) + task_queue.append(['add_res', None]) + + +def resblock2task(queue, block): + """ + Turn a ResNetBlock into a sequence of tasks and append to the task queue + + @param queue: the target task queue + @param block: ResNetBlock + + """ + if block.in_channels != block.out_channels: + if block.use_conv_shortcut: + queue.append(('store_res', block.conv_shortcut)) + else: + queue.append(('store_res', block.nin_shortcut)) + else: + queue.append(('store_res', lambda x: x)) + queue.append(('pre_norm', block.norm1)) + queue.append(('silu', inplace_nonlinearity)) + queue.append(('conv1', block.conv1)) + queue.append(('pre_norm', block.norm2)) + queue.append(('silu', inplace_nonlinearity)) + queue.append(('conv2', block.conv2)) + queue.append(['add_res', None]) + + +def build_sampling(task_queue, net, is_decoder): + """ + Build the sampling part of a task queue + @param task_queue: the target task queue + @param net: the network + @param is_decoder: currently building decoder or encoder + """ + if is_decoder: + resblock2task(task_queue, net.mid.block_1) + attn2task(task_queue, net.mid.attn_1) + 
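# The task-queue idea used throughout this file: the VAE forward pass is flattened
# into a list of (name, callable) steps so execution can pause at every 'pre_norm'
# and resume tile by tile. A stripped-down sketch with a toy layer follows; unlike
# resblock2task(), it keeps the stored residual in a local variable instead of
# writing it back into the 'add_res' slot.
import torch
import torch.nn as nn

def run_task_queue(task_queue, x):
    res = None
    for name, fn in task_queue:
        if name == 'store_res':      # remember a branch of the computation
            res = fn(x)
        elif name == 'add_res':      # add the remembered branch back (skip connection)
            x = x + res
        else:                        # ordinary layer
            x = fn(x)
    return x

conv = nn.Conv2d(3, 3, 3, padding=1)
queue = [
    ('store_res', lambda t: t),
    ('conv', conv),
    ('add_res', None),
]
out = run_task_queue(queue, torch.randn(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 3, 16, 16])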
resblock2task(task_queue, net.mid.block_2) + resolution_iter = reversed(range(net.num_resolutions)) + block_ids = net.num_res_blocks + 1 + condition = 0 + module = net.up + func_name = 'upsample' + else: + resolution_iter = range(net.num_resolutions) + block_ids = net.num_res_blocks + condition = net.num_resolutions - 1 + module = net.down + func_name = 'downsample' + + for i_level in resolution_iter: + for i_block in range(block_ids): + resblock2task(task_queue, module[i_level].block[i_block]) + if i_level != condition: + task_queue.append((func_name, getattr(module[i_level], func_name))) + + if not is_decoder: + resblock2task(task_queue, net.mid.block_1) + attn2task(task_queue, net.mid.attn_1) + resblock2task(task_queue, net.mid.block_2) + + +def build_task_queue(net, is_decoder): + """ + Build a single task queue for the encoder or decoder + @param net: the VAE decoder or encoder network + @param is_decoder: currently building decoder or encoder + @return: the task queue + """ + task_queue = [] + task_queue.append(('conv_in', net.conv_in)) + + # construct the sampling part of the task queue + # because encoder and decoder share the same architecture, we extract the sampling part + build_sampling(task_queue, net, is_decoder) + + if not is_decoder or not net.give_pre_end: + task_queue.append(('pre_norm', net.norm_out)) + task_queue.append(('silu', inplace_nonlinearity)) + task_queue.append(('conv_out', net.conv_out)) + if is_decoder and net.tanh_out: + task_queue.append(('tanh', torch.tanh)) + + return task_queue + + +def clone_task_queue(task_queue): + """ + Clone a task queue + @param task_queue: the task queue to be cloned + @return: the cloned task queue + """ + return [[item for item in task] for task in task_queue] + + +def get_var_mean(input, num_groups, eps=1e-6): + """ + Get mean and var for group norm + """ + b, c = input.size(0), input.size(1) + channel_in_group = int(c/num_groups) + input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:]) + var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False) + return var, mean + + +def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): + """ + Custom group norm with fixed mean and var + + @param input: input tensor + @param num_groups: number of groups. 
by default, num_groups = 32 + @param mean: mean, must be pre-calculated by get_var_mean + @param var: var, must be pre-calculated by get_var_mean + @param weight: weight, should be fetched from the original group norm + @param bias: bias, should be fetched from the original group norm + @param eps: epsilon, by default, eps = 1e-6 to match the original group norm + + @return: normalized tensor + """ + b, c = input.size(0), input.size(1) + channel_in_group = int(c/num_groups) + input_reshaped = input.contiguous().view( + 1, int(b * num_groups), channel_in_group, *input.size()[2:]) + + out = F.batch_norm(input_reshaped, mean.to(input), var.to(input), weight=None, bias=None, training=False, momentum=0, eps=eps) + out = out.view(b, c, *input.size()[2:]) + + # post affine transform + if weight is not None: + out *= weight.view(1, -1, 1, 1) + if bias is not None: + out += bias.view(1, -1, 1, 1) + return out + + +def crop_valid_region(x, input_bbox, target_bbox, is_decoder): + """ + Crop the valid region from the tile + @param x: input tile + @param input_bbox: original input bounding box + @param target_bbox: output bounding box + @param scale: scale factor + @return: cropped tile + """ + padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox] + margin = [target_bbox[i] - padded_bbox[i] for i in range(4)] + return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]] + + +# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓ + +def perfcount(fn): + def wrapper(*args, **kwargs): + ts = time() + + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(devices.device) + devices.torch_gc() + gc.collect() + + ret = fn(*args, **kwargs) + + devices.torch_gc() + gc.collect() + if torch.cuda.is_available(): + vram = torch.cuda.max_memory_allocated(devices.device) / 2**20 + print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB') + else: + print(f'[Tiled VAE]: Done in {time() - ts:.3f}s') + + return ret + return wrapper + +# ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑ + + +class GroupNormParam: + + def __init__(self): + self.var_list = [] + self.mean_list = [] + self.pixel_list = [] + self.weight = None + self.bias = None + + def add_tile(self, tile, layer): + var, mean = get_var_mean(tile, 32) + # For giant images, the variance can be larger than max float16 + # In this case we create a copy to float32 + if var.dtype == torch.float16 and var.isinf().any(): + fp32_tile = tile.float() + var, mean = get_var_mean(fp32_tile, 32) + # ============= DEBUG: test for infinite ============= + # if torch.isinf(var).any(): + # print('var: ', var) + # ==================================================== + self.var_list.append(var) + self.mean_list.append(mean) + self.pixel_list.append( + tile.shape[2]*tile.shape[3]) + if hasattr(layer, 'weight'): + self.weight = layer.weight + self.bias = layer.bias + else: + self.weight = None + self.bias = None + + def summary(self): + """ + summarize the mean and var and return a function + that apply group norm on each tile + """ + if len(self.var_list) == 0: return None + + var = torch.vstack(self.var_list) + mean = torch.vstack(self.mean_list) + max_value = max(self.pixel_list) + pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value + sum_pixels = torch.sum(pixels) + pixels = pixels.unsqueeze(1) / sum_pixels + var = torch.sum(var * pixels, dim=0) + mean = torch.sum(mean * pixels, dim=0) + return lambda x: custom_group_norm(x, 32, 
mean, var, self.weight, self.bias) + + @staticmethod + def from_tile(tile, norm): + """ + create a function from a single tile without summary + """ + var, mean = get_var_mean(tile, 32) + if var.dtype == torch.float16 and var.isinf().any(): + fp32_tile = tile.float() + var, mean = get_var_mean(fp32_tile, 32) + # if it is a macbook, we need to convert back to float16 + if var.device.type == 'mps': + # clamp to avoid overflow + var = torch.clamp(var, 0, 60000) + var = var.half() + mean = mean.half() + if hasattr(norm, 'weight'): + weight = norm.weight + bias = norm.bias + else: + weight = None + bias = None + + def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): + return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) + return group_norm_func + + +class VAEHook: + + def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False): + self.net = net # encoder | decoder + self.tile_size = tile_size + self.is_decoder = is_decoder + self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder) + self.color_fix = color_fix and not is_decoder + self.to_gpu = to_gpu + self.pad = 11 if is_decoder else 32 # FIXME: magic number + + def __call__(self, x): + original_device = next(self.net.parameters()).device + try: + if self.to_gpu: + self.net = self.net.to(devices.get_optimal_device()) + + B, C, H, W = x.shape + if max(H, W) <= self.pad * 2 + self.tile_size: + print("[Tiled VAE]: the input size is tiny and unnecessary to tile.") + return self.net.original_forward(x) + else: + return self.vae_tile_forward(x) + finally: + self.net = self.net.to(original_device) + + def get_best_tile_size(self, lowerbound, upperbound): + """ + Get the best tile size for GPU memory + """ + divider = 32 + while divider >= 2: + remainer = lowerbound % divider + if remainer == 0: + return lowerbound + candidate = lowerbound - remainer + divider + if candidate <= upperbound: + return candidate + divider //= 2 + return lowerbound + + def split_tiles(self, h, w): + """ + Tool function to split the image into tiles + @param h: height of the image + @param w: width of the image + @return: tile_input_bboxes, tile_output_bboxes + """ + tile_input_bboxes, tile_output_bboxes = [], [] + tile_size = self.tile_size + pad = self.pad + num_height_tiles = math.ceil((h - 2 * pad) / tile_size) + num_width_tiles = math.ceil((w - 2 * pad) / tile_size) + # If any of the numbers are 0, we let it be 1 + # This is to deal with long and thin images + num_height_tiles = max(num_height_tiles, 1) + num_width_tiles = max(num_width_tiles, 1) + + # Suggestions from https://github.com/Kahsolt: auto shrink the tile size + real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles) + real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles) + real_tile_height = self.get_best_tile_size(real_tile_height, tile_size) + real_tile_width = self.get_best_tile_size(real_tile_width, tile_size) + + print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' + + f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}') + + for i in range(num_height_tiles): + for j in range(num_width_tiles): + # bbox: [x1, x2, y1, y2] + # the padding is is unnessary for image borders. 
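# A one-dimensional sketch of the bbox bookkeeping below: tiles cover the padded
# interior, input bboxes are expanded by `pad` on each side for context, and output
# bboxes snap to the image border so the merged result stays seamless. The numbers
# in the example call are illustrative only.
import math

def split_axis(length, tile_size, pad):
    n = max(math.ceil((length - 2 * pad) / tile_size), 1)
    real = math.ceil((length - 2 * pad) / n)          # auto-shrunk tile size
    in_spans, out_spans = [], []
    for j in range(n):
        a = pad + j * real
        b = min(pad + (j + 1) * real, length)
        # output span extends to the border whenever the tile touches it
        out_spans.append((a if a > pad else 0, b if b < length - pad else length))
        # input span is expanded by pad to give the network context
        in_spans.append((max(0, a - pad), min(length, b + pad)))
    return in_spans, out_spans

print(split_axis(200, 64, 11))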
So we directly start from (32, 32) + input_bbox = [ + pad + j * real_tile_width, + min(pad + (j + 1) * real_tile_width, w), + pad + i * real_tile_height, + min(pad + (i + 1) * real_tile_height, h), + ] + + # if the output bbox is close to the image boundary, we extend it to the image boundary + output_bbox = [ + input_bbox[0] if input_bbox[0] > pad else 0, + input_bbox[1] if input_bbox[1] < w - pad else w, + input_bbox[2] if input_bbox[2] > pad else 0, + input_bbox[3] if input_bbox[3] < h - pad else h, + ] + + # scale to get the final output bbox + output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox] + tile_output_bboxes.append(output_bbox) + + # indistinguishable expand the input bbox by pad pixels + tile_input_bboxes.append([ + max(0, input_bbox[0] - pad), + min(w, input_bbox[1] + pad), + max(0, input_bbox[2] - pad), + min(h, input_bbox[3] + pad), + ]) + + return tile_input_bboxes, tile_output_bboxes + + @torch.no_grad() + def estimate_group_norm(self, z, task_queue, color_fix): + device = z.device + tile = z + last_id = len(task_queue) - 1 + while last_id >= 0 and task_queue[last_id][0] != 'pre_norm': + last_id -= 1 + if last_id <= 0 or task_queue[last_id][0] != 'pre_norm': + raise ValueError('No group norm found in the task queue') + # estimate until the last group norm + for i in range(last_id + 1): + task = task_queue[i] + if task[0] == 'pre_norm': + group_norm_func = GroupNormParam.from_tile(tile, task[1]) + task_queue[i] = ('apply_norm', group_norm_func) + if i == last_id: + return True + tile = group_norm_func(tile) + elif task[0] == 'store_res': + task_id = i + 1 + while task_id < last_id and task_queue[task_id][0] != 'add_res': + task_id += 1 + if task_id >= last_id: + continue + task_queue[task_id][1] = task[1](tile) + elif task[0] == 'add_res': + tile += task[1].to(device) + task[1] = None + elif color_fix and task[0] == 'downsample': + for j in range(i, last_id + 1): + if task_queue[j][0] == 'store_res': + task_queue[j] = ('store_res_cpu', task_queue[j][1]) + return True + else: + tile = task[1](tile) + try: + devices.test_for_nans(tile, "vae") + except: + print(f'Nan detected in fast mode estimation. Fast mode disabled.') + return False + + raise IndexError('Should not reach here') + + @perfcount + @torch.no_grad() + def vae_tile_forward(self, z): + """ + Decode a latent vector z into an image in a tiled manner. 
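# Sketch of the zigzag visiting order used by vae_tile_forward(): tiles are walked
# forward on one pass and backward on the next, so the tile where execution pauses
# is also the first one resumed and never has to leave the GPU in between. This is
# a simplified toy loop, not the actual scheduling code.
def zigzag_passes(num_tiles, num_passes):
    order, forward = [], True
    for _ in range(num_passes):
        order.extend(range(num_tiles) if forward else reversed(range(num_tiles)))
        forward = not forward
    return order

print(zigzag_passes(4, 3))  # [0, 1, 2, 3, 3, 2, 1, 0, 0, 1, 2, 3]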
+ @param z: latent vector + @return: image + """ + device = next(self.net.parameters()).device + dtype = next(self.net.parameters()).dtype + net = self.net + tile_size = self.tile_size + is_decoder = self.is_decoder + + z = z.detach() # detach the input to avoid backprop + + N, height, width = z.shape[0], z.shape[2], z.shape[3] + net.last_z_shape = z.shape + + # Split the input into tiles and build a task queue for each tile + print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}') + + in_bboxes, out_bboxes = self.split_tiles(height, width) + + # Prepare tiles by split the input latents + tiles = [] + for input_bbox in in_bboxes: + tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu() + tiles.append(tile) + + num_tiles = len(tiles) + num_completed = 0 + + # Build task queues + single_task_queue = build_task_queue(net, is_decoder) + if self.fast_mode: + # Fast mode: downsample the input image to the tile size, + # then estimate the group norm parameters on the downsampled image + scale_factor = tile_size / max(height, width) + z = z.to(device) + downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact') + # use nearest-exact to keep statictics as close as possible + print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image') + + # ======= Special thanks to @Kahsolt for distribution shift issue ======= # + # The downsampling will heavily distort its mean and std, so we need to recover it. + std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True) + std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True) + downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old + del std_old, mean_old, std_new, mean_new + # occasionally the std_new is too small or too large, which exceeds the range of float16 + # so we need to clamp it to max z's range. 
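# Minimal sketch of the statistics matching performed here: downsampling shifts the
# per-channel mean/std of the latent, so the downsampled copy is renormalized back
# to the original statistics (and clamped to the original range) before it is used
# to estimate GroupNorm parameters. Toy tensors only.
import torch
import torch.nn.functional as F

z = torch.randn(1, 4, 96, 96) * 3.0 + 0.5
small = F.interpolate(z, scale_factor=0.25, mode='nearest-exact')
std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
std_new, mean_new = torch.std_mean(small, dim=[0, 2, 3], keepdim=True)
small = (small - mean_new) / std_new * std_old + mean_old  # match mean/std per channel
small = small.clamp_(min=z.min(), max=z.max())             # stay inside the original range
print(small.mean().item(), z.mean().item())                # now close to each other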
+ downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) + estimate_task_queue = clone_task_queue(single_task_queue) + if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): + single_task_queue = estimate_task_queue + del downsampled_z + + task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] + + # Dummy result + result = None + result_approx = None + try: + with devices.autocast(): + result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() + except: pass + # Free memory of input latent tensor + del z + + # Task queue execution + pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") + + # execute the task back and forth when switch tiles so that we always + # keep one tile on the GPU to reduce unnecessary data transfer + forward = True + interrupted = False + #state.interrupted = interrupted + while True: + if state.interrupted: interrupted = True ; break + + group_norm_param = GroupNormParam() + for i in range(num_tiles) if forward else reversed(range(num_tiles)): + if state.interrupted: interrupted = True ; break + + tile = tiles[i].to(device) + input_bbox = in_bboxes[i] + task_queue = task_queues[i] + + interrupted = False + while len(task_queue) > 0: + if state.interrupted: interrupted = True ; break + + # DEBUG: current task + # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape) + task = task_queue.pop(0) + if task[0] == 'pre_norm': + group_norm_param.add_tile(tile, task[1]) + break + elif task[0] == 'store_res' or task[0] == 'store_res_cpu': + task_id = 0 + res = task[1](tile) + if not self.fast_mode or task[0] == 'store_res_cpu': + res = res.cpu() + while task_queue[task_id][0] != 'add_res': + task_id += 1 + task_queue[task_id][1] = res + elif task[0] == 'add_res': + tile += task[1].to(device) + task[1] = None + else: + tile = task[1](tile) + pbar.update(1) + + if interrupted: break + + # check for NaNs in the tile. + # If there are NaNs, we abort the process to save user's time + devices.test_for_nans(tile, "vae") + + if len(task_queue) == 0: + tiles[i] = None + num_completed += 1 + if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically + result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False) + result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder) + del tile + elif i == num_tiles - 1 and forward: + forward = False + tiles[i] = tile + elif i == 0 and not forward: + forward = True + tiles[i] = tile + else: + tiles[i] = tile.cpu() + del tile + + if interrupted: break + if num_completed == num_tiles: break + + # insert the group norm task to the head of each task queue + group_norm_func = group_norm_param.summary() + if group_norm_func is not None: + for i in range(num_tiles): + task_queue = task_queues[i] + task_queue.insert(0, ('apply_norm', group_norm_func)) + + # Done! 
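# Sketch of how the per-tile GroupNorm statistics collected above are merged and
# applied: (var, mean) pairs are averaged with weights proportional to tile area,
# and the fixed statistics are pushed through F.batch_norm, which is what
# custom_group_norm() does. Toy tensors and a small group count for illustration.
import torch
import torch.nn.functional as F

def merged_group_norm(tiles, num_groups=2, eps=1e-6):
    vars_, means, pixels = [], [], []
    for t in tiles:
        b, c = t.shape[:2]
        r = t.reshape(1, b * num_groups, c // num_groups, *t.shape[2:])
        v, m = torch.var_mean(r, dim=[0, 2, 3, 4], unbiased=False)
        vars_.append(v)
        means.append(m)
        pixels.append(t.shape[2] * t.shape[3])
    w = torch.tensor(pixels, dtype=torch.float32)
    w = (w / w.sum()).unsqueeze(1)                    # weight statistics by tile area
    var = (torch.vstack(vars_) * w).sum(dim=0)
    mean = (torch.vstack(means) * w).sum(dim=0)

    def apply(x):
        b, c = x.shape[:2]
        r = x.reshape(1, b * num_groups, c // num_groups, *x.shape[2:])
        out = F.batch_norm(r, mean, var, training=False, momentum=0, eps=eps)
        return out.reshape(b, c, *x.shape[2:])
    return apply

tiles = [torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 16)]
norm = merged_group_norm(tiles)
print(norm(tiles[0]).shape)  # torch.Size([1, 4, 8, 8])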
+ pbar.close() + return result.to(dtype) if result is not None else result_approx.to(device, dtype=dtype) + + +class Script(scripts.Script): + + def __init__(self): + self.hooked = False + + def title(self): + return "Tiled VAE" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + tab = 't2i' if not is_img2img else 'i2i' + uid = lambda name: f'MD-{tab}-{name}' + + with ( + InputAccordion(False, label='Tiled VAE', elem_id=f'MDV-{tab}-enabled') if InputAccordion + else gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}') + as enabled + ): + with gr.Row() as tab_enable: + if not InputAccordion: + enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable')) + vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu')) + + gr.HTML('Recommended to set tile sizes as large as possible before got CUDA error: out of memory.
') + with gr.Row() as tab_size: + encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size')) + decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size')) + reset = gr.Button(value='↻ Reset', variant='tool') + reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False) + + with gr.Row() as tab_param: + fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc')) + color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix')) + fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec')) + + fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False) + + return [ + enabled, + encoder_tile_size, decoder_tile_size, + vae_to_gpu, fast_decoder, fast_encoder, color_fix, + ] + + def process(self, p:Processing, + enabled:bool, + encoder_tile_size:int, decoder_tile_size:int, + vae_to_gpu:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool + ): + + # for shorthand + vae = p.sd_model.first_stage_model + encoder = vae.encoder + decoder = vae.decoder + + # undo hijack if disabled (in cases last time crashed) + if not enabled: + if self.hooked: + if isinstance(encoder.forward, VAEHook): + encoder.forward.net = None + encoder.forward = encoder.original_forward + if isinstance(decoder.forward, VAEHook): + decoder.forward.net = None + decoder.forward = decoder.original_forward + self.hooked = False + return + + if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu: + print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.") + + # do hijack + kwargs = { + 'fast_decoder': fast_decoder, + 'fast_encoder': fast_encoder, + 'color_fix': color_fix, + 'to_gpu': vae_to_gpu, + } + + # save original forward (only once) + if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward) + if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward) + + self.hooked = True + + encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs) + decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs) + + def postprocess(self, p:Processing, processed, enabled:bool, *args): + if not enabled: return + + vae = p.sd_model.first_stage_model + encoder = vae.encoder + decoder = vae.decoder + if isinstance(encoder.forward, VAEHook): + encoder.forward.net = None + encoder.forward = encoder.original_forward + if isinstance(decoder.forward, VAEHook): + decoder.forward.net = None + decoder.forward = decoder.original_forward diff --git a/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/abstractdiffusion.cpython-310.pyc b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/abstractdiffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8c2c8df35d3835feae2abbc545eb35b96d85b04 Binary files /dev/null and b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/abstractdiffusion.cpython-310.pyc differ diff --git a/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/demofusion.cpython-310.pyc 
b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/demofusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..888412ca6acb9ab5fc12845cf96b4bc9ad682f33 Binary files /dev/null and b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/demofusion.cpython-310.pyc differ diff --git a/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/mixtureofdiffusers.cpython-310.pyc b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/mixtureofdiffusers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87d14e5afdcfec86f7aeb26e7e0d64a44602e31 Binary files /dev/null and b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/mixtureofdiffusers.cpython-310.pyc differ diff --git a/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/multidiffusion.cpython-310.pyc b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/multidiffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df9e59715f9a03a73815275ee019edf53cfcac92 Binary files /dev/null and b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/multidiffusion.cpython-310.pyc differ diff --git a/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/abstractdiffusion.py b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/abstractdiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..88916f3881479355d6ac0ab0425836a09255e126 --- /dev/null +++ b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/abstractdiffusion.py @@ -0,0 +1,747 @@ +from tile_utils.utils import * + + +class AbstractDiffusion: + + def __init__(self, p: Processing, sampler: Sampler): + self.method = self.__class__.__name__ + self.p: Processing = p + self.pbar = None + + # sampler + self.sampler_name = p.sampler_name + self.sampler_raw = sampler + self.sampler = sampler + + # fix. Kdiff 'AND' support and image editing model support + if self.is_kdiff and not hasattr(self, 'is_edit_model'): + self.is_edit_model = (shared.sd_model.cond_stage_key == "edit" # "txt" + and self.sampler.model_wrap_cfg.image_cfg_scale is not None + and self.sampler.model_wrap_cfg.image_cfg_scale != 1.0) + + # cache. final result of current sampling step, [B, C=4, H//8, W//8] + # avoiding overhead of creating new tensors and weight summing + self.x_buffer: Tensor = None + self.w: int = int(self.p.width // opt_f) # latent size + self.h: int = int(self.p.height // opt_f) + # weights for background & grid bboxes + self.weights: Tensor = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32) + + # FIXME: I'm trying to count the step correctly but it's not working + self.step_count = 0 + self.inner_loop_count = 0 + self.kdiff_step = -1 + + # ext. Grid tiling painting (grid bbox) + self.enable_grid_bbox: bool = False + self.tile_w: int = None + self.tile_h: int = None + self.tile_bs: int = None + self.num_tiles: int = None + self.num_batches: int = None + self.batched_bboxes: List[List[BBox]] = [] + + # ext. 
Region Prompt Control (custom bbox) + self.enable_custom_bbox: bool = False + self.custom_bboxes: List[CustomBBox] = [] + self.cond_basis: Cond = None + self.uncond_basis: Uncond = None + self.draw_background: bool = True # by default we draw major prompts in grid tiles + self.causal_layers: bool = None + + # ext. Noise Inversion (noise inversion) + self.noise_inverse_enabled: bool = False + self.noise_inverse_steps: int = 0 + self.noise_inverse_retouch: float = None + self.noise_inverse_renoise_strength: float = None + self.noise_inverse_renoise_kernel: int = None + self.noise_inverse_get_cache = None + self.noise_inverse_set_cache = None + self.sample_img2img_original = None + + # ext. ControlNet + self.enable_controlnet: bool = False + self.controlnet_script: ModuleType = None + self.control_tensor_batch: List[List[Tensor]] = [] + self.control_params: Dict[str, Tensor] = {} + self.control_tensor_cpu: bool = None + self.control_tensor_custom: List[List[Tensor]] = [] + + # ext. StableSR + self.enable_stablesr: bool = False + self.stablesr_script: ModuleType = None + self.stablesr_tensor: Tensor = None + self.stablesr_tensor_batch: List[Tensor] = [] + self.stablesr_tensor_custom: List[Tensor] = [] + + @property + def is_kdiff(self): + return isinstance(self.sampler_raw, KDiffusionSampler) + + @property + def is_ddim(self): + return isinstance(self.sampler_raw, CompVisSampler) + + def update_pbar(self): + if self.pbar.n >= self.pbar.total: + self.pbar.close() + else: + if self.step_count == state.sampling_step: + self.inner_loop_count += 1 + if self.inner_loop_count < self.total_bboxes: + self.pbar.update() + else: + self.step_count = state.sampling_step + self.inner_loop_count = 0 + + def reset_buffer(self, x_in:Tensor): + # Judge if the shape of x_in is the same as the shape of x_buffer + if self.x_buffer is None or self.x_buffer.shape != x_in.shape: + self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype) + else: + self.x_buffer.zero_() + + def init_done(self): + ''' + Call this after all `init_*`, settings are done, now perform: + - settings sanity check + - pre-computations, cache init + - anything thing needed before denoising starts + ''' + + self.total_bboxes = 0 + if self.enable_grid_bbox: self.total_bboxes += self.num_batches + if self.enable_custom_bbox: self.total_bboxes += len(self.custom_bboxes) + assert self.total_bboxes > 0, "Nothing to paint! No background to draw and no custom bboxes were provided." 
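# Sketch of the grid-tiling bookkeeping that the `weights` buffer above and
# init_grid_bbox() below set up: the latent is covered by overlapping tiles, a
# weight map counts how many tiles touch each pixel, and per-tile results are
# summed into a buffer and divided by the weights so overlaps blend smoothly.
# Toy sizes and a uniform per-tile weight of 1; the real code also supports
# non-uniform tile weights.
import math
import torch

def make_grid(w, h, tile_w, tile_h, overlap):
    stride_w, stride_h = tile_w - overlap, tile_h - overlap
    cols = max(math.ceil((w - overlap) / stride_w), 1)
    rows = max(math.ceil((h - overlap) / stride_h), 1)
    bboxes, weights = [], torch.zeros(1, 1, h, w)
    for r in range(rows):
        for c in range(cols):
            x = min(c * stride_w, w - tile_w)
            y = min(r * stride_h, h - tile_h)
            bboxes.append((x, y, tile_w, tile_h))
            weights[:, :, y:y+tile_h, x:x+tile_w] += 1
    return bboxes, weights

bboxes, weights = make_grid(w=96, h=64, tile_w=48, tile_h=48, overlap=8)
buffer = torch.zeros(1, 4, 64, 96)
for (x, y, tw, th) in bboxes:
    buffer[:, :, y:y+th, x:x+tw] += torch.randn(1, 4, th, tw)  # stand-in for a denoised tile
blended = buffer / weights                                     # average overlapped areas
print(len(bboxes), blended.shape)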
+ + self.pbar = tqdm(total=(self.total_bboxes) * state.sampling_steps, desc=f"{self.method} Sampling: ") + + ''' ↓↓↓ cond_dict utils ↓↓↓ ''' + + def _tcond_key(self, cond_dict:CondDict) -> str: + return 'crossattn' if 'crossattn' in cond_dict else 'c_crossattn' + + def get_tcond(self, cond_dict:CondDict) -> Tensor: + tcond = cond_dict[self._tcond_key(cond_dict)] + if isinstance(tcond, list): tcond = tcond[0] + return tcond + + def set_tcond(self, cond_dict:CondDict, tcond:Tensor): + key = self._tcond_key(cond_dict) + if isinstance(cond_dict[key], list): tcond = [tcond] + cond_dict[key] = tcond + + def _icond_key(self, cond_dict:CondDict) -> str: + return 'c_adm' if shared.sd_model.model.conditioning_key in ['crossattn-adm', 'adm'] else 'c_concat' + + def get_icond(self, cond_dict:CondDict) -> Tensor: + ''' icond differs for different models (inpaint/unclip model) ''' + key = self._icond_key(cond_dict) + icond = cond_dict[key] + if isinstance(icond, list): icond = icond[0] + return icond + + def set_icond(self, cond_dict:CondDict, icond:Tensor): + key = self._icond_key(cond_dict) + if isinstance(cond_dict[key], list): icond = [icond] + cond_dict[key] = icond + + def _vcond_key(self, cond_dict:CondDict) -> Optional[str]: + return 'vector' if 'vector' in cond_dict else None + + def get_vcond(self, cond_dict:CondDict) -> Optional[Tensor]: + ''' vector for SDXL ''' + key = self._vcond_key(cond_dict) + return cond_dict.get(key) + + def set_vcond(self, cond_dict:CondDict, vcond:Optional[Tensor]): + key = self._vcond_key(cond_dict) + if key is not None: + cond_dict[key] = vcond + + def make_cond_dict(self, cond_in:CondDict, tcond:Tensor, icond:Tensor, vcond:Tensor=None) -> CondDict: + ''' copy & replace the content, returns a new object ''' + cond_out = cond_in.copy() + self.set_tcond(cond_out, tcond) + self.set_icond(cond_out, icond) + self.set_vcond(cond_out, vcond) + return cond_out + + ''' ↓↓↓ extensive functionality ↓↓↓ ''' + + @grid_bbox + def init_grid_bbox(self, tile_w:int, tile_h:int, overlap:int, tile_bs:int): + self.enable_grid_bbox = True + + self.tile_w = min(tile_w, self.w) + self.tile_h = min(tile_h, self.h) + overlap = max(0, min(overlap, min(tile_w, tile_h) - 4)) + # split the latent into overlapped tiles, then batching + # weights basically indicate how many times a pixel is painted + bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights()) + self.weights += weights + self.num_tiles = len(bboxes) + self.num_batches = math.ceil(self.num_tiles / tile_bs) + self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size + self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)] + + @grid_bbox + def get_tile_weights(self) -> Union[Tensor, float]: + return 1.0 + + + @custom_bbox + def init_custom_bbox(self, bbox_settings:Dict[int,BBoxSettings], draw_background:bool, causal_layers:bool): + self.enable_custom_bbox = True + + self.causal_layers = causal_layers + self.draw_background = draw_background + if not draw_background: + self.enable_grid_bbox = False + self.weights.zero_() + + self.custom_bboxes: List[CustomBBox] = [] + for bbox_setting in bbox_settings.values(): + e, x, y, w, h, p, n, blend_mode, feather_ratio, seed = bbox_setting + if not e or x > 1.0 or y > 1.0 or w <= 0.0 or h <= 0.0: continue + x = int(x * self.w) + y = int(y * self.h) + w = math.ceil(w * self.w) + h = math.ceil(h * self.h) + x = max(0, x) + y = max(0, y) + w = min(self.w - x, w) + h = min(self.h - y, h) 
+ self.custom_bboxes.append(CustomBBox(x, y, w, h, p, n, blend_mode, feather_ratio, seed)) + + if len(self.custom_bboxes) == 0: + self.enable_custom_bbox = False + return + + # prepare cond + p = self.p + prompts = p.all_prompts[:p.batch_size] + neg_prompts = p.all_negative_prompts[:p.batch_size] + for bbox in self.custom_bboxes: + bbox.cond, bbox.extra_network_data = Condition.get_custom_cond(prompts, bbox.prompt, p.steps, p.styles) + bbox.uncond = Condition.get_uncond(Prompt.append_prompt(neg_prompts, bbox.neg_prompt), p.steps, p.styles) + self.cond_basis = Condition.get_cond(prompts, p.steps) + self.uncond_basis = Condition.get_uncond(neg_prompts, p.steps) + + @custom_bbox + def reconstruct_custom_cond(self, org_cond:CondDict, custom_cond:Cond, custom_uncond:Uncond, bbox:CustomBBox) -> Tuple[List, Tensor, Uncond, Tensor]: + image_conditioning = None + if isinstance(org_cond, dict): + icond = self.get_icond(org_cond) + if icond.shape[2:] == (self.h, self.w): # img2img + icond = icond[bbox.slicer] + image_conditioning = icond + + sampler_step = self.sampler.model_wrap_cfg.step + tensor = Condition.reconstruct_cond(custom_cond, sampler_step) + custom_uncond = Condition.reconstruct_uncond(custom_uncond, sampler_step) + return tensor, custom_uncond, image_conditioning + + @custom_bbox + def kdiff_custom_forward(self, x_tile:Tensor, sigma_in:Tensor, original_cond:CondDict, bbox_id:int, bbox:CustomBBox, forward_func:Callable) -> Tensor: + ''' + The inner kdiff noise prediction is usually batched. + We need to unwrap the inside loop to simulate the batched behavior. + This can be extremely tricky. + ''' + + sampler_step = self.sampler.model_wrap_cfg.step + if self.kdiff_step != sampler_step: + self.kdiff_step = sampler_step + self.kdiff_step_bbox = [-1 for _ in range(len(self.custom_bboxes))] + self.tensor = {} # {int: Tensor[cond]} + self.uncond = {} # {int: Tensor[cond]} + self.image_cond_in = {} + # Initialize global prompts just for estimate the behavior of kdiff + self.real_tensor = Condition.reconstruct_cond(self.cond_basis, sampler_step) + self.real_uncond = Condition.reconstruct_uncond(self.uncond_basis, sampler_step) + # reset the progress for all bboxes + self.a = [0 for _ in range(len(self.custom_bboxes))] + + if self.kdiff_step_bbox[bbox_id] != sampler_step: + # When a new step starts for a bbox, we need to judge whether the tensor is batched. + self.kdiff_step_bbox[bbox_id] = sampler_step + + tensor, uncond, image_cond_in = self.reconstruct_custom_cond(original_cond, bbox.cond, bbox.uncond, bbox) + + if self.real_tensor.shape[1] == self.real_uncond.shape[1]: + if shared.batch_cond_uncond: + # when the real tensor is with equal length, all information is contained in x_tile. + # we simulate the batched behavior and compute all the tensors in one go. + if tensor.shape[1] == uncond.shape[1]: + # When our prompt tensor is with equal length, we can directly their code. + if not self.is_edit_model: + cond = torch.cat([tensor, uncond]) + else: + cond = torch.cat([tensor, uncond, uncond]) + self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, cond, image_cond_in), + ) + else: + # When not, we need to pass the tensor to UNet separately. 
+ x_out = torch.zeros_like(x_tile) + cond_size = tensor.shape[0] + self.set_custom_controlnet_tensors(bbox_id, cond_size) + self.set_custom_stablesr_tensors(bbox_id) + cond_out = forward_func( + x_tile [:cond_size], + sigma_in[:cond_size], + cond=self.make_cond_dict(original_cond, tensor, image_cond_in[:cond_size]), + ) + uncond_size = uncond.shape[0] + self.set_custom_controlnet_tensors(bbox_id, uncond_size) + self.set_custom_stablesr_tensors(bbox_id) + uncond_out = forward_func( + x_tile [cond_size:cond_size+uncond_size], + sigma_in[cond_size:cond_size+uncond_size], + cond=self.make_cond_dict(original_cond, uncond, image_cond_in[cond_size:cond_size+uncond_size]), + ) + x_out[:cond_size] = cond_out + x_out[cond_size:cond_size+uncond_size] = uncond_out + if self.is_edit_model: + x_out[cond_size+uncond_size:] = uncond_out + return x_out + + # otherwise, the x_tile is only a partial batch. + # We have to denoise in different runs. + # We store the prompt and neg_prompt tensors for current bbox + self.tensor[bbox_id] = tensor + self.uncond[bbox_id] = uncond + self.image_cond_in[bbox_id] = image_cond_in + + # Now we get current batch of prompt and neg_prompt tensors + tensor: Tensor = self.tensor[bbox_id] + uncond: Tensor = self.uncond[bbox_id] + batch_size = x_tile.shape[0] + # get the start and end index of the current batch + a = self.a[bbox_id] + b = a + batch_size + self.a[bbox_id] += batch_size + + if self.real_tensor.shape[1] == self.real_uncond.shape[1]: + # When use --lowvram or --medvram, kdiff will slice the cond and uncond with [a:b] + # So we need to slice our tensor and uncond with the same index as original kdiff. + + # --- original code in kdiff --- + # if not self.is_edit_model: + # cond = torch.cat([tensor, uncond]) + # else: + # cond = torch.cat([tensor, uncond, uncond]) + # cond = cond[a:b] + # ------------------------------ + + # The original kdiff code is to concat and then slice, but this cannot apply to + # our custom prompt tensor when tensor.shape[1] != uncond.shape[1]. So we adapt it. + cond_in, uncond_in = None, None + # Slice the [prompt, neg prompt, (possibly) neg prompt] with [a:b] + if not self.is_edit_model: + if b <= tensor.shape[0]: cond_in = tensor[a:b] + elif a >= tensor.shape[0]: cond_in = uncond[a-tensor.shape[0]:b-tensor.shape[0]] + else: + cond_in = tensor[a:] + uncond_in = uncond[:b-tensor.shape[0]] + else: + if b <= tensor.shape[0]: + cond_in = tensor[a:b] + elif b > tensor.shape[0] and b <= tensor.shape[0] + uncond.shape[0]: + if a>= tensor.shape[0]: + cond_in = uncond[a-tensor.shape[0]:b-tensor.shape[0]] + else: + cond_in = tensor[a:] + uncond_in = uncond[:b-tensor.shape[0]] + else: + if a >= tensor.shape[0] + uncond.shape[0]: + cond_in = uncond[a-tensor.shape[0]-uncond.shape[0]:b-tensor.shape[0]-uncond.shape[0]] + elif a >= tensor.shape[0]: + cond_in = torch.cat([uncond[a-tensor.shape[0]:], uncond[:b-tensor.shape[0]-uncond.shape[0]]]) + + if uncond_in is None or tensor.shape[1] == uncond.shape[1]: + # If the tensor can be passed to UNet in one go, do it. + if uncond_in is not None: + cond_in = torch.cat([cond_in, uncond_in]) + self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, cond_in, self.image_cond_in[bbox_id]), + ) + else: + # If not, we need to pass the tensor to UNet separately. 
+ x_out = torch.zeros_like(x_tile) + cond_size = cond_in.shape[0] + self.set_custom_controlnet_tensors(bbox_id, cond_size) + self.set_custom_stablesr_tensors(bbox_id) + cond_out = forward_func( + x_tile [:cond_size], + sigma_in[:cond_size], + cond=self.make_cond_dict(original_cond, cond_in, self.image_cond_in[bbox_id]) + ) + self.set_custom_controlnet_tensors(bbox_id, uncond_in.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + uncond_out = forward_func( + x_tile [cond_size:], + sigma_in[cond_size:], + cond=self.make_cond_dict(original_cond, uncond_in, self.image_cond_in[bbox_id]) + ) + x_out[:cond_size] = cond_out + x_out[cond_size:] = uncond_out + return x_out + + # If the original prompt is with different length, + # kdiff will deal with the cond and uncond separately. + # Hence we also deal with the tensor and uncond separately. + # get the start and end index of the current batch + + if a < tensor.shape[0]: + # Deal with custom prompt tensor + if not self.is_edit_model: + c_crossattn = tensor[a:b] + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + # complete this batch. + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, c_crossattn, self.image_cond_in[bbox_id]) + ) + else: + # if the cond is finished, we need to process the uncond. + self.set_custom_controlnet_tensors(bbox_id, uncond.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, uncond, self.image_cond_in[bbox_id]) + ) + + @custom_bbox + def ddim_custom_forward(self, x:Tensor, cond_in:CondDict, bbox:CustomBBox, ts:Tensor, forward_func:Callable, *args, **kwargs) -> Tensor: + ''' draw custom bbox ''' + + tensor, uncond, image_conditioning = self.reconstruct_custom_cond(cond_in, bbox.cond, bbox.uncond, bbox) + + cond = tensor + # for DDIM, shapes definitely match. So we dont need to do the same thing as in the KDIFF sampler. + if uncond.shape[1] < cond.shape[1]: + last_vector = uncond[:, -1:] + last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond.shape[1], 1]) + uncond = torch.hstack([uncond, last_vector_repeated]) + elif uncond.shape[1] > cond.shape[1]: + uncond = uncond[:, :cond.shape[1]] + + # Wrap the image conditioning back up since the DDIM code can accept the dict directly. + # Note that they need to be lists because it just concatenates them later. + if image_conditioning is not None: + cond = self.make_cond_dict(cond_in, cond, image_conditioning) + uncond = self.make_cond_dict(cond_in, uncond, image_conditioning) + + # We cannot determine the batch size here for different methods, so delay it to the forward_func. 
+ return forward_func(x, cond, ts, unconditional_conditioning=uncond, *args, **kwargs) + + + @controlnet + def init_controlnet(self, controlnet_script:ModuleType, control_tensor_cpu:bool): + self.enable_controlnet = True + + self.controlnet_script = controlnet_script + self.control_tensor_cpu = control_tensor_cpu + self.control_tensor_batch = None + self.control_params = None + self.control_tensor_custom = [] + + self.prepare_controlnet_tensors() + + @controlnet + def reset_controlnet_tensors(self): + if not self.enable_controlnet: return + if self.control_tensor_batch is None: return + + for param_id in range(len(self.control_params)): + self.control_params[param_id].hint_cond = self.org_control_tensor_batch[param_id] + + @controlnet + def prepare_controlnet_tensors(self, refresh:bool=False): + ''' Crop the control tensor into tiles and cache them ''' + + if not refresh: + if self.control_tensor_batch is not None or self.control_params is not None: return + + if not self.enable_controlnet or self.controlnet_script is None: return + + latest_network = self.controlnet_script.latest_network + if latest_network is None or not hasattr(latest_network, 'control_params'): return + + self.control_params = latest_network.control_params + tensors = [param.hint_cond for param in latest_network.control_params] + self.org_control_tensor_batch = tensors + + if len(tensors) == 0: return + + self.control_tensor_batch = [] + for i in range(len(tensors)): + control_tile_list = [] + control_tensor = tensors[i] + for bboxes in self.batched_bboxes: + single_batch_tensors = [] + for bbox in bboxes: + if len(control_tensor.shape) == 3: + control_tensor.unsqueeze_(0) + control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f] + single_batch_tensors.append(control_tile) + control_tile = torch.cat(single_batch_tensors, dim=0) + if self.control_tensor_cpu: + control_tile = control_tile.cpu() + control_tile_list.append(control_tile) + self.control_tensor_batch.append(control_tile_list) + + if len(self.custom_bboxes) > 0: + custom_control_tile_list = [] + for bbox in self.custom_bboxes: + if len(control_tensor.shape) == 3: + control_tensor.unsqueeze_(0) + control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f] + if self.control_tensor_cpu: + control_tile = control_tile.cpu() + custom_control_tile_list.append(control_tile) + self.control_tensor_custom.append(custom_control_tile_list) + + @controlnet + def switch_controlnet_tensors(self, batch_id:int, x_batch_size:int, tile_batch_size:int, is_denoise=False): + if not self.enable_controlnet: return + if self.control_tensor_batch is None: return + + for param_id in range(len(self.control_params)): + control_tile = self.control_tensor_batch[param_id][batch_id] + if self.is_kdiff: + all_control_tile = [] + for i in range(tile_batch_size): + this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size + all_control_tile.append(torch.cat(this_control_tile, dim=0)) + control_tile = torch.cat(all_control_tile, dim=0) + else: + control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1]) + self.control_params[param_id].hint_cond = control_tile.to(devices.device) + + @controlnet + def set_custom_controlnet_tensors(self, bbox_id:int, repeat_size:int): + if not self.enable_controlnet: return + if not len(self.control_tensor_custom): return + + for param_id in range(len(self.control_params)): + control_tensor = 
self.control_tensor_custom[param_id][bbox_id].to(devices.device) + self.control_params[param_id].hint_cond = control_tensor.repeat((repeat_size, 1, 1, 1)) + + + @stablesr + def init_stablesr(self, stablesr_script:ModuleType): + if stablesr_script.stablesr_model is None: return + self.stablesr_script = stablesr_script + def set_image_hook(latent_image): + self.enable_stablesr = True + self.stablesr_tensor = latent_image + self.stablesr_tensor_batch = [] + for bboxes in self.batched_bboxes: + single_batch_tensors = [] + for bbox in bboxes: + stablesr_tile = self.stablesr_tensor[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] + single_batch_tensors.append(stablesr_tile) + stablesr_tile = torch.cat(single_batch_tensors, dim=0) + self.stablesr_tensor_batch.append(stablesr_tile) + if len(self.custom_bboxes) > 0: + self.stablesr_tensor_custom = [] + for bbox in self.custom_bboxes: + stablesr_tile = self.stablesr_tensor[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] + self.stablesr_tensor_custom.append(stablesr_tile) + + stablesr_script.stablesr_model.set_image_hooks['TiledDiffusion'] = set_image_hook + + @stablesr + def reset_stablesr_tensors(self): + if not self.enable_stablesr: return + if self.stablesr_script.stablesr_model is None: return + self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor + + @stablesr + def switch_stablesr_tensors(self, batch_id:int): + if not self.enable_stablesr: return + if self.stablesr_script.stablesr_model is None: return + if self.stablesr_tensor_batch is None: return + self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor_batch[batch_id] + + @stablesr + def set_custom_stablesr_tensors(self, bbox_id:int): + if not self.enable_stablesr: return + if self.stablesr_script.stablesr_model is None: return + if not len(self.stablesr_tensor_custom): return + self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor_custom[bbox_id] + + + @noise_inverse + def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int): + self.noise_inverse_enabled = True + self.noise_inverse_steps = steps + self.noise_inverse_retouch = float(retouch) + self.noise_inverse_renoise_strength = float(renoise_strength) + self.noise_inverse_renoise_kernel = int(renoise_kernel) + if self.sample_img2img_original is None: + self.sample_img2img_original = self.sampler_raw.sample_img2img + self.sampler_raw.sample_img2img = MethodType(self.sample_img2img, self.sampler_raw) + self.noise_inverse_set_cache = set_cache_callback + self.noise_inverse_get_cache = get_cache_callback + + @noise_inverse + @keep_signature + def sample_img2img(self, sampler: KDiffusionSampler, p:ProcessingImg2Img, + x:Tensor, noise:Tensor, conditioning, unconditional_conditioning, + steps=None, image_conditioning=None): + # noise inverse sampling - renoise mask + import torch.nn.functional as F + renoise_mask = None + if self.noise_inverse_renoise_strength > 0: + image = p.init_images[0] + # convert to grayscale with PIL + image = image.convert('L') + np_mask = get_retouch_mask(np.asarray(image), self.noise_inverse_renoise_kernel) + renoise_mask = torch.from_numpy(np_mask).to(noise.device) + # resize retouch mask to match noise size + renoise_mask = 1 - F.interpolate(renoise_mask.unsqueeze(0).unsqueeze(0), size=noise.shape[-2:], mode='bilinear').squeeze(0).squeeze(0) + renoise_mask *= self.noise_inverse_renoise_strength + renoise_mask = torch.clamp(renoise_mask, 0, 1) + + prompts = p.all_prompts[:p.batch_size] + + latent = 
None + # try to use cached latent to save huge amount of time. + cached_latent: NoiseInverseCache = self.noise_inverse_get_cache() + if cached_latent is not None and \ + cached_latent.model_hash == p.sd_model.sd_model_hash and \ + cached_latent.noise_inversion_steps == self.noise_inverse_steps and \ + len(cached_latent.prompts) == len(prompts) and \ + all([cached_latent.prompts[i] == prompts[i] for i in range(len(prompts))]) and \ + abs(cached_latent.retouch - self.noise_inverse_retouch) < 0.01 and \ + cached_latent.x0.shape == p.init_latent.shape and \ + torch.abs(cached_latent.x0.to(p.init_latent.device) - p.init_latent).sum() < 100: # the 100 is an arbitrary threshold copy-pasted from the img2img alt code + # use cached noise + print('[Tiled Diffusion] Your checkpoint, image, prompts, inverse steps, and retouch params are all unchanged.') + print('[Tiled Diffusion] Noise Inversion will use the cached noise from the previous run. To clear the cache, click the Free GPU button.') + latent = cached_latent.xt.to(noise.device) + if latent is None: + # run noise inversion + shared.state.job_count += 1 + latent = self.find_noise_for_image_sigma_adjustment(sampler.model_wrap, self.noise_inverse_steps, prompts) + shared.state.nextjob() + self.noise_inverse_set_cache(p.init_latent.clone().cpu(), latent.clone().cpu(), prompts) + # The cache is only 1 latent image and is very small (16 MB for 8192 * 8192 image), so we don't need to worry about memory leakage. + + # calculate sampling steps + adjusted_steps, _ = sd_samplers_common.setup_img2img_steps(p, steps) + sigmas = sampler.get_sigmas(p, adjusted_steps) + inverse_noise = latent - (p.init_latent / sigmas[0]) + + # inject noise to high-frequency area so that the details won't lose too much + if renoise_mask is not None: + # If the background is not drawn, we need to filter out the un-drawn pixels and reweight foreground with feather mask + # This is to enable the renoise mask in regional inpainting + if not self.enable_grid_bbox: + background_count = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device) + foreground_noise = torch.zeros_like(noise) + foreground_weight = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device) + foreground_count = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device) + for bbox in self.custom_bboxes: + if bbox.blend_mode == BlendMode.BACKGROUND: + background_count[bbox.slicer] += 1 + elif bbox.blend_mode == BlendMode.FOREGROUND: + foreground_noise [bbox.slicer] += noise[bbox.slicer] + foreground_weight[bbox.slicer] += bbox.feather_mask + foreground_count [bbox.slicer] += 1 + background_noise = torch.where(background_count > 0, noise, 0) + foreground_noise = torch.where(foreground_count > 0, foreground_noise / foreground_count, 0) + foreground_weight = torch.where(foreground_count > 0, foreground_weight / foreground_count, 0) + noise = background_noise * (1 - foreground_weight) + foreground_noise * foreground_weight + del background_noise, foreground_noise, foreground_weight, background_count, foreground_count + combined_noise = ((1 - renoise_mask) * inverse_noise + renoise_mask * noise) / ((renoise_mask**2 + (1 - renoise_mask)**2) ** 0.5) + else: + combined_noise = inverse_noise + + # use the estimated noise for the original img2img sampling + return self.sample_img2img_original(p, x, combined_noise, conditioning, unconditional_conditioning, steps, image_conditioning) + + @noise_inverse + @torch.no_grad() + def find_noise_for_image_sigma_adjustment(self, dnw, 
steps, prompts:List[str]) -> Tensor: + ''' + Migrate from the built-in script img2imgalt.py + Tiled noise inverse for better image upscaling + ''' + import k_diffusion as K + assert self.p.sampler_name == 'Euler' + + x = self.p.init_latent + s_in = x.new_ones([x.shape[0]]) + skip = 1 if shared.sd_model.parameterization == "v" else 0 + sigmas = dnw.get_sigmas(steps).flip(0) + + cond = self.p.sd_model.get_learned_conditioning(prompts) + if isinstance(cond, Tensor): # SD1/SD2 + cond_dict_dummy = { + 'c_crossattn': [], # List[Tensor] + 'c_concat': [], # List[Tensor] + } + cond_in = self.make_cond_dict(cond_dict_dummy, cond, self.p.image_conditioning) + else: # SDXL + cond_dict_dummy = { + 'crossattn': None, # Tensor + 'vector': None, # Tensor + 'c_concat': [], # List[Tensor] + } + cond_in = self.make_cond_dict(cond_dict_dummy, cond['crossattn'], self.p.image_conditioning, cond['vector']) + + state.sampling_steps = steps + pbar = tqdm(total=steps, desc='Noise Inversion') + for i in range(1, len(sigmas)): + if state.interrupted: return x + + state.sampling_step += 1 + + x_in = x + sigma_in = torch.cat([sigmas[i] * s_in]) + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]] + + t = dnw.sigma_to_t(sigma_in) + t = t / self.noise_inverse_retouch + + eps = self.get_noise(x_in * c_in, t, cond_in, steps - i) + denoised = x_in + eps * c_out + + # Euler method: + d = (x - denoised) / sigmas[i] + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + + sd_samplers_common.store_latent(x) + + # This is neccessary to save memory before the next iteration + del x_in, sigma_in, c_out, c_in, t, + del eps, denoised, d, dt + + pbar.update(1) + pbar.close() + + return x / sigmas[-1] + + @noise_inverse + @torch.no_grad() + def get_noise(self, x_in: Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: + raise NotImplementedError diff --git a/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/demofusion.py b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/demofusion.py new file mode 100644 index 0000000000000000000000000000000000000000..758ccfe0d13c6e92b878660a475905cf325a29fd --- /dev/null +++ b/extensions/CHECK/multidiffusion-upscaler-for-automatic1111/tile_methods/demofusion.py @@ -0,0 +1,353 @@ +from tile_methods.abstractdiffusion import AbstractDiffusion +from tile_utils.utils import * +import torch.nn.functional as F +import random +from copy import deepcopy +import inspect +from modules import sd_samplers_common + + +class DemoFusion(AbstractDiffusion): + """ + DemoFusion Implementation + https://arxiv.org/abs/2311.16973 + """ + + def __init__(self, p:Processing, *args, **kwargs): + super().__init__(p, *args, **kwargs) + assert p.sampler_name != 'UniPC', 'Demofusion is not compatible with UniPC!' 
+ + + def hook(self): + steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None) + + self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward + self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward + self.sampler.model_wrap_cfg.forward = self.forward_one_step + if self.is_kdiff: + self.sampler: KDiffusionSampler + self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion + self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser] + else: + self.sampler: CompVisSampler + self.sampler.model_wrap_cfg: CFGDenoiserTimesteps + self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser] + self.timesteps = self.sampler.get_timesteps(self.p, steps) + + @staticmethod + def unhook(): + if hasattr(shared.sd_model, 'apply_model_ori'): + shared.sd_model.apply_model = shared.sd_model.apply_model_ori + del shared.sd_model.apply_model_ori + + def reset_buffer(self, x_in:Tensor): + super().reset_buffer(x_in) + + + + def repeat_tensor(self, x:Tensor, n:int) -> Tensor: + ''' repeat the tensor on it's first dim ''' + if n == 1: return x + B = x.shape[0] + r_dims = len(x.shape) - 1 + if B == 1: # batch_size = 1 (not `tile_batch_size`) + shape = [n] + [-1] * r_dims # [N, -1, ...] + return x.expand(shape) # `expand` is much lighter than `tile` + else: + shape = [n] + [1] * r_dims # [N, 1, ...] + return x.repeat(shape) + + def repeat_cond_dict(self, cond_in:CondDict, bboxes,mode) -> CondDict: + ''' repeat all tensors in cond_dict on it's first dim (for a batch of tiles), returns a new object ''' + # n_repeat + n_rep = len(bboxes) + # txt cond + tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D] + tcond = self.repeat_tensor(tcond, n_rep) + # img cond + icond = self.get_icond(cond_in) + if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W] + if mode == 0: + if self.p.random_jitter: + jitter_range = self.jitter_range + icond = F.pad(icond,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) + icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0) + else: + icond = torch.cat([icond[:,:,bbox[1]::self.p.current_scale_num,bbox[0]::self.p.current_scale_num] for bbox in bboxes], dim=0) + else: # txt2img, [B=1, C=5, H=1, W=1] + icond = self.repeat_tensor(icond, n_rep) + + # vec cond (SDXL) + vcond = self.get_vcond(cond_in) # [B=1, D] + if vcond is not None: + vcond = self.repeat_tensor(vcond, n_rep) # [B*N, D] + return self.make_cond_dict(cond_in, tcond, icond, vcond) + + + def global_split_bboxes(self): + cols = self.p.current_scale_num + rows = cols + + bbox_list = [] + for row in range(rows): + y = row + for col in range(cols): + x = col + bbox = (x, y) + bbox_list.append(bbox) + + return bbox_list+bbox_list if self.p.mixture else bbox_list + + def split_bboxes_jitter(self,w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]: + cols = math.ceil((w_l - overlap) / (tile_w - overlap)) + rows = math.ceil((h_l - overlap) / (tile_h - overlap)) + if rows==0: + rows=1 + if cols == 0: + cols=1 + dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0 + dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0 + bbox_list: List[BBox] = [] + self.jitter_range = 0 + for row in range(rows): + for col in range(cols): + h = min(int(row * dy), h_l - tile_h) + w = min(int(col * dx), w_l - tile_w) + if self.p.random_jitter: + self.jitter_range = min(max((min(self.w, 
self.h)-self.stride)//4,0),min(int(self.window_size/2),int(self.overlap/2))) + jitter_range = self.jitter_range + w_jitter = 0 + h_jitter = 0 + if (w != 0) and (w+tile_w != w_l): + w_jitter = random.randint(-jitter_range, jitter_range) + elif (w == 0) and (w + tile_w != w_l): + w_jitter = random.randint(-jitter_range, 0) + elif (w != 0) and (w + tile_w == w_l): + w_jitter = random.randint(0, jitter_range) + if (h != 0) and (h + tile_h != h_l): + h_jitter = random.randint(-jitter_range, jitter_range) + elif (h == 0) and (h + tile_h != h_l): + h_jitter = random.randint(-jitter_range, 0) + elif (h != 0) and (h + tile_h == h_l): + h_jitter = random.randint(0, jitter_range) + h +=(h_jitter + jitter_range) + w += (w_jitter + jitter_range) + + bbox = BBox(w, h, tile_w, tile_h) + bbox_list.append(bbox) + return bbox_list, None + + @grid_bbox + def get_views(self, overlap:int, tile_bs:int,tile_bs_g:int): + self.enable_grid_bbox = True + self.tile_w = self.window_size + self.tile_h = self.window_size + + self.overlap = max(0, min(overlap, self.window_size - 4)) + + self.stride = max(4,self.window_size - self.overlap) + + # split the latent into overlapped tiles, then batching + # weights basically indicate how many times a pixel is painted + bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights()) + self.num_tiles = len(bboxes) + self.num_batches = math.ceil(self.num_tiles / tile_bs) + self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size + self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)] + + global_bboxes = self.global_split_bboxes() + self.global_num_tiles = len(global_bboxes) + self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs_g) + self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches) + self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)] + + def gaussian_kernel(self,kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size, device=devices.device) + gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] + kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) + + return kernel + + def gaussian_filter(self,latents, kernel_size=3, sigma=1.0): + channels = latents.shape[1] + kernel = self.gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + + return blurred_latents + + + + ''' ↓↓↓ kernel hijacks ↓↓↓ ''' + @torch.no_grad() + @keep_signature + def forward_one_step(self, x_in, sigma, **kwarg): + if self.is_kdiff: + x_noisy = self.p.x + self.p.noise * sigma[0] + else: + alphas_cumprod = self.p.sd_model.alphas_cumprod + sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) + sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) + x_noisy = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod + + self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1))))) + + c1 = self.cosine_factor ** self.p.cosine_scale_1 + + x_in = x_in*(1 - c1) + x_noisy * c1 + + if self.p.random_jitter: + jitter_range = 
self.jitter_range + else: + jitter_range = 0 + x_in_ = F.pad(x_in,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) + _,_,H,W = x_in.shape + + self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step + self.repeat_3 = False + + x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_,sigma, **kwarg) + self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward + x_out = x_out[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W] + + return x_out + + + @torch.no_grad() + @keep_signature + def sample_one_step(self, x_in, sigma, cond): + assert LatentDiffusion.apply_model + def repeat_func_1(x_tile:Tensor, bboxes,mode=0) -> Tensor: + sigma_tile = self.repeat_tensor(sigma, len(bboxes)) + cond_tile = self.repeat_cond_dict(cond, bboxes,mode) + return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) + + def repeat_func_2(x_tile:Tensor, bboxes,mode=0) -> Tuple[Tensor, Tensor]: + n_rep = len(bboxes) + ts_tile = self.repeat_tensor(sigma, n_rep) + if isinstance(cond, dict): # FIXME: when will enter this branch? + cond_tile = self.repeat_cond_dict(cond, bboxes,mode) + else: + cond_tile = self.repeat_tensor(cond, n_rep) + return self.sampler_forward(x_tile, ts_tile, cond=cond_tile) + + def repeat_func_3(x_tile:Tensor, bboxes,mode=0): + sigma_in_tile = sigma.repeat(len(bboxes)) + cond_out = self.repeat_cond_dict(cond, bboxes,mode) + x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out) + return x_tile_out + + if self.repeat_3: + repeat_func = repeat_func_3 + self.repeat_3 = False + elif self.is_kdiff: + repeat_func = repeat_func_1 + else: + repeat_func = repeat_func_2 + N,_,_,_ = x_in.shape + + + self.x_buffer = torch.zeros_like(x_in) + self.weights = torch.zeros_like(x_in) + + for batch_id, bboxes in enumerate(self.batched_bboxes): + if state.interrupted: return x_in + x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) + x_tile_out = repeat_func(x_tile, bboxes) + # de-batching + for i, bbox in enumerate(bboxes): + self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] + self.weights[bbox.slicer] += 1 + self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode + + x_local = self.x_buffer/self.weights + + self.x_buffer = torch.zeros_like(self.x_buffer) + self.weights = torch.zeros_like(self.weights) + + std_, mean_ = x_in.std(), x_in.mean() + c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2 + if self.p.gaussian_filter: + x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num-1), sigma=self.sig*c3) + x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_ + + if not hasattr(self.p.sd_model, 'apply_model_ori'): + self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model + self.p.sd_model.apply_model = self.apply_model_hijack + x_global = torch.zeros_like(x_local) + jitter_range = self.jitter_range + end = x_global.shape[3]-jitter_range + + current_num = 0 + if self.p.mixture: + for batch_id, bboxes in enumerate(self.global_batched_bboxes): + current_num += len(bboxes) + if current_num > (self.global_num_tiles//2) and (current_num-self.global_tile_bs) < (self.global_num_tiles//2): + res = len(bboxes) - (current_num - self.global_num_tiles//2) + x_in_i = torch.cat([x_in[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] if idx
+Separator:{SEP}
| Feature | Example | Support |
|---|---|---|
| Control Net | OpenPose | Yes |
| Wildcards | __colors__ | Yes |
| Single LoRA | Style | Yes |
| Multi-LoRA | Characters | Limited |
| Prompt Scheduling | [from:to:steps] | No |
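To make the table concrete, a hypothetical prompt (the wildcard file and LoRA names are made up) showing the supported features and the unsupported one:

```
# supported: wildcard + a single LoRA
a portrait in a __colors__ dress, <lora:watercolor_style:0.8>

# not supported here: prompt scheduling
[forest:city:10] background, <lora:watercolor_style:0.8>
```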
Disable for Negative prompt option. Default is True.
true / false / Lerp.
Target tokens. Default is _ (underbar).
+Off | On
+Off | On
+Off | On
changeable blocks : BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11
') + xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index") + with gr.Row(visible = False) as esets: + diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True) + revxy = gr.Checkbox(value = False,label="change X-Y",interactive =True,elem_id="lbw_changexy") + thresh = gr.Textbox(label="difference threshold",lines=1,value="20",interactive =True,elem_id="diff_thr") + xtype = gr.Dropdown(label="X Types", choices=[x for x in ATYPES], value=ATYPES [2],interactive =True,elem_id="lbw_xtype") + xmen = gr.Textbox(label="X Values",lines=1,value="0,0.25,0.5,0.75,1",interactive =True,elem_id="lbw_xmen") + ytype = gr.Dropdown(label="Y Types", choices=[y for y in ATYPES], value=ATYPES [1],interactive =True,elem_id="lbw_ytype") + ymen = gr.Textbox(label="Y Values" ,lines=1,value="IN05-OUT05",interactive =True,elem_id="lbw_ymen") + ztype = gr.Dropdown(label="Z type", choices=[z for z in ATYPES], value=ATYPES[0],interactive =True,elem_id="lbw_ztype") + zmen = gr.Textbox(label="Z values",lines=1,value="",interactive =True,elem_id="lbw_zmen") + + exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False) + eymen = gr.Textbox(label="Blocks (12ALL,17ALL,20ALL,26ALL also can be used)" ,lines=1,value="BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False) + ecount = gr.Number(value=1, label="number of seed", interactive=True, visible = True) + + with gr.Accordion("Weights setting",open = True): + with gr.Row(): + reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload") + reloadtags = gr.Button(value="Reload Tags",variant='primary',elem_id="lbw_reload") + savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext") + openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor") + lbw_loraratios = gr.TextArea(label="",value=lbwpresets,visible =True,interactive = True,elem_id="lbw_ratiospreset") + + with gr.Accordion("Elemental",open = False): + with gr.Row(): + e_reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload") + e_savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext") + e_openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor") + elemsets = gr.Checkbox(value = False,label="print change",interactive =True,elem_id="lbw_print_change") + elemental = gr.TextArea(label="Identifer:BlockID:Elements:Ratio,...,separated by empty line ",value = elempresets,interactive =True,elem_id="element") + + d_true = gr.Checkbox(value = True,visible = False) + d_false = gr.Checkbox(value = False,visible = False) + + with gr.Accordion("Make Weights",open = False): + with gr.Row(): + m_text = gr.Textbox(value="",label="Weights") + with gr.Row(): + m_add = gr.Button(value="Add to presets",size="sm",variant='primary') + m_add_save = gr.Button(value="Add to presets and Save",size="sm",variant='primary') + m_name = gr.Textbox(value="",label="Identifier") + with gr.Row(): + m_type = gr.Radio(label="Weights type",choices=["17(1.X/2.X)", "26(1.X/2.X full)", "12(XL)","20(XL full)"], value="17(1.X/2.X)") + with gr.Row(): + m_set_0 = gr.Button(value="Set All 0",variant='primary') + m_set_1 = gr.Button(value="Set All 1",variant='primary') + 
m_custom = gr.Button(value="Set custom",variant='primary') + m_custom_v = gr.Slider(show_label=False, minimum=-1.0, maximum=1, step=0.1, value=0, interactive=True) + with gr.Row(): + with gr.Column(scale=1, min_width=100): + gr.Slider(visible=False) + with gr.Column(scale=2, min_width=200): + base = gr.Slider(label="BASE", minimum=-1, maximum=1, step=0.1, value=0.0) + with gr.Column(scale=1, min_width=100): + gr.Slider(visible=False) + with gr.Row(): + with gr.Column(scale=2, min_width=200): + ins = [gr.Slider(label=block, minimum=-1.0, maximum=1, step=0.1, value=0, interactive=True) for block in BLOCKID26[1:13]] + with gr.Column(scale=2, min_width=200): + outs = [gr.Slider(label=block, minimum=-1.0, maximum=1, step=0.1, value=0, interactive=True) for block in reversed(BLOCKID26[14:])] + with gr.Row(): + with gr.Column(scale=1, min_width=100): + gr.Slider(visible=False) + with gr.Column(scale=2, min_width=200): + m00 = gr.Slider(label="M00", minimum=-1, maximum=1, step=0.1, value=0.0) + with gr.Column(scale=1, min_width=100): + gr.Slider(visible=False) + + blocks = [base] + ins + [m00] + outs[::-1] + for block in blocks: + if block.label not in BLOCKID17: + block.visible = False + + m_set_0.click(fn=lambda x:[0]*26 + [",".join(["0"]*int(x[:2]))],inputs=[m_type],outputs=blocks + [m_text]) + m_set_1.click(fn=lambda x:[1]*26 + [",".join(["1"]*int(x[:2]))],inputs=[m_type],outputs=blocks + [m_text]) + m_custom.click(fn=lambda x,y:[x]*26 + [",".join([str(x)]*int(y[:2]))],inputs=[m_custom_v,m_type],outputs=blocks + [m_text]) + + def addweights(weights, id, presets, save = False): + if id == "":id = "NONAME" + lines = presets.strip().split("\n") + id_found = False + for i, line in enumerate(lines): + if line.startswith("#"): + continue + if line.split(":")[0] == id: + lines[i] = f"{id}:{weights}" + id_found = True + break + if not id_found: + lines.append(f"{id}:{weights}") + + if save: + with open(extpath,mode = 'w',encoding="utf-8") as f: + f.write("\n".join(lines)) + + return "\n".join(lines) + + def changetheblocks(sdver,*blocks): + sdver = int(sdver[:2]) + output = [] + targ_blocks = BLOCKIDS[BLOCKNUMS.index(sdver)] + for i, block in enumerate(BLOCKID26): + if block in targ_blocks: + output.append(str(blocks[i])) + return [",".join(output)] + [gr.update(visible = True if block in targ_blocks else False) for block in BLOCKID26] + + m_add.click(fn=addweights, inputs=[m_text,m_name,lbw_loraratios],outputs=[lbw_loraratios]) + m_add_save.click(fn=addweights, inputs=[m_text,m_name,lbw_loraratios, d_true],outputs=[lbw_loraratios]) + m_type.change(fn=changetheblocks, inputs=[m_type] + blocks,outputs=[m_text] + blocks) + + d_true = gr.Checkbox(value = True,visible = False) + d_false = gr.Checkbox(value = False,visible = False) + + lbw_useblocks.change(fn=lambda x:gr.update(label = f"LoRA Block Weight : {'Active' if x else 'Not Active'}"),inputs=lbw_useblocks, outputs=[acc]) + + def makeweights(sdver, *blocks): + sdver = int(sdver[:2]) + output = [] + targ_blocks = BLOCKIDS[BLOCKNUMS.index(sdver)] + for i, block in enumerate(BLOCKID26): + if block in targ_blocks: + output.append(str(blocks[i])) + return ",".join(output) + + changes = [b.release(fn=makeweights,inputs=[m_type] + blocks,outputs=[m_text]) for b in blocks] + + import subprocess + def openeditors(b): + path = extpath if b else extpathe + subprocess.Popen(['start', path], shell=True) + + def reloadpresets(isweight): + if isweight: + try: + with open(extpath,encoding="utf-8") as f: + return f.read() + except OSError as e: + pass + else: + try: 
+ with open(extpath,encoding="utf-8") as f: + return f.read() + except OSError as e: + pass + + def tagdicter(presets): + presets=presets.splitlines() + wdict={} + for l in presets: + if checkloadcond(l) : continue + w=[] + if ":" in l : + key = l.split(":",1)[0] + w = l.split(":",1)[1] + if any(len([w for w in w.split(",")]) == x for x in BLOCKNUMS): + wdict[key.strip()]=w + return ",".join(list(wdict.keys())) + + def savepresets(text,isweight): + if isweight: + with open(extpath,mode = 'w',encoding="utf-8") as f: + f.write(text) + else: + with open(extpathe,mode = 'w',encoding="utf-8") as f: + f.write(text) + + reloadtext.click(fn=reloadpresets,inputs=[d_true],outputs=[lbw_loraratios]) + reloadtags.click(fn=tagdicter,inputs=[lbw_loraratios],outputs=[bw_ratiotags]) + savetext.click(fn=savepresets,inputs=[lbw_loraratios,d_true],outputs=[]) + openeditor.click(fn=openeditors,inputs=[d_true],outputs=[]) + + e_reloadtext.click(fn=reloadpresets,inputs=[d_false],outputs=[elemental]) + e_savetext.click(fn=savepresets,inputs=[elemental,d_false],outputs=[]) + e_openeditor.click(fn=openeditors,inputs=[d_false],outputs=[]) + + def urawaza(active): + if active > 0: + register() + scripts.scripts_txt2img.run = newrun + scripts.scripts_img2img.run = newrun + if active == 1:return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]] + else:return [*[gr.update(visible = False) for x in range(6)],*[gr.update(visible = True) for x in range(4)]] + else: + scripts.scripts_txt2img.run = runorigin + scripts.scripts_img2img.run = runorigini + return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]] + + xyzsetting.change(fn=urawaza,inputs=[xyzsetting],outputs =[xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,esets]) + + return lbw_loraratios,lbw_useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug + + def process(self, p, loraratios,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug): + #print("self =",self,"p =",p,"presets =",loraratios,"useblocks =",useblocks,"xyzsettings =",xyzsetting,"xtype =",xtype,"xmen =",xmen,"ytype =",ytype,"ymen =",ymen,"ztype =",ztype,"zmen =",zmen) + #Note that this does not use the default arg syntax because the default args are supposed to be at the end of the function + if(loraratios == None): + loraratios = DEF_WEIGHT_PRESET + if(useblocks == None): + useblocks = True + + lorachecker(self) + self.log["enable LBW"] = useblocks + self.log["registerd"] = registerd + + if useblocks: + self.active = True + loraratios=loraratios.splitlines() + elemental = elemental.split("\n\n") if elemental is not None else [] + lratios={} + elementals={} + for l in loraratios: + if checkloadcond(l) : continue + l0=l.split(":",1)[0] + lratios[l0.strip()]=l.split(":",1)[1] + for e in elemental: + if ":" not in e: continue + e0=e.split(":",1)[0] + elementals[e0.strip()]=e.split(":",1)[1] + if elemsets : print(xyelem) + if xyzsetting and "XYZ" in p.prompt: + lratios["XYZ"] = lxyz + lratios["ZYX"] = lzyx + if xyelem != "": + if "XYZ" in elementals.keys(): + elementals["XYZ"] = elementals["XYZ"] + ","+ xyelem + else: + elementals["XYZ"] = xyelem + self.lratios = lratios + self.elementals = elementals + global princ + princ = elemsets + + if not hasattr(self,"lbt_dr_callbacks"): + self.lbt_dr_callbacks = on_cfg_denoiser(self.denoiser_callback) + + def denoiser_callback(self, params: 
CFGDenoiserParams): + def setparams(self, key, te, u ,sets): + for dicts in [self.lora,self.lycoris,self.networks]: + for lora in dicts: + if lora.name.split("_in_LBW_")[0] == key: + lora.te_multiplier = te + lora.unet_multiplier = u + sets.append(key) + + if forge and self.active: + if params.sampling_step in self.startsf: + shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device) + for key, vals in shared.sd_model.forge_objects.unet.patches.items(): + n_vals = [] + lvals = [val for val in vals if val[1][0] in LORAS] + for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef): + if s is not None and s == params.sampling_step: + ratio, errormodules = ratiodealer(key.replace(".","_"), l, e) + n_vals.append((ratio * m, *v[1:])) + else: + n_vals.append(v) + shared.sd_model.forge_objects.unet.patches[key] = n_vals + shared.sd_model.forge_objects.unet.patch_model() + + if params.sampling_step in self.stopsf: + shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device) + for key, vals in shared.sd_model.forge_objects.unet.patches.items(): + n_vals = [] + lvals = [val for val in vals if val[1][0] in LORAS] + for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef): + if s is not None and s == params.sampling_step: + n_vals.append((0, *v[1:])) + else: + n_vals.append(v) + shared.sd_model.forge_objects.unet.patches[key] = n_vals + shared.sd_model.forge_objects.unet.patch_model() + + elif self.active: + if self.starts and params.sampling_step == 0: + for key, step_te_u in self.starts.items(): + setparams(self, key, 0, 0, []) + #print("\nstart 0", self, key, 0, 0, []) + + if self.starts: + sets = [] + for key, step_te_u in self.starts.items(): + step, te, u = step_te_u + if params.sampling_step > step - 2: + setparams(self, key, te, u, sets) + #print("\nstart", self, key, u, te, sets) + for key in sets: + if key in self.starts: + del self.starts[key] + + if self.stops: + sets = [] + for key, step in self.stops.items(): + if params.sampling_step > step - 2: + setparams(self, key, 0, 0, sets) + #print("\nstop", self, key, 0, 0, sets) + for key in sets: + if key in self.stops: + del self.stops[key] + + def before_process_batch(self, p, loraratios,useblocks,*args,**kwargs): + if useblocks: + resetmemory() + if not self.isnet: p.disable_extra_networks = False + global prompts + prompts = kwargs["prompts"].copy() + + def process_batch(self, p, loraratios,useblocks,*args,**kwargs): + if useblocks: + if not self.isnet: p.disable_extra_networks = True + + o_prompts = [p.prompt] + for prompt in prompts: + if "
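For context on the preset buttons above: `addweights`/`savepresets` persist each preset as one `IDENTIFIER:weights` line, `addweights` skips lines starting with `#`, and `tagdicter` only lists entries whose weight count matches a supported block layout (17, 26, 12 or 20 values, per the "Weights type" radio). A hypothetical 26-value (1.X/2.X full) snippet in BASE, IN00..IN11, M00, OUT00..OUT11 order; the identifiers and values are made up:

```
# hypothetical LBW presets, 26 weights per entry
FLAT1:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
NOMID:1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1
```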
Please test on small images before actual upscale. Default params require denoise <= 0.6
') + with gr.Row(variant='compact'): + noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) + noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) + noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) + + # The control includes txt2img and img2img, we use t2i and i2i to distinguish them + with gr.Group(elem_id=f'MD-bbox-control-{tab}') as tab_bbox: + with gr.Accordion('Region Prompt Control', open=False): + with gr.Row(variant='compact'): + enable_bbox_control = gr.Checkbox(label='Enable Control', value=False, elem_id=uid('enable-bbox-control')) + draw_background = gr.Checkbox(label='Draw full canvas background', value=False, elem_id=uid('draw-background')) + causal_layers = gr.Checkbox(label='Causalize layers', value=False, visible=False, elem_id='MD-causal-layers') # NOTE: currently not used + + with gr.Row(variant='compact'): + create_button = gr.Button(value="Create txt2img canvas" if not is_img2img else "From img2img", elem_id='MD-create-canvas') + + bbox_controls: List[Component] = [] # control set for each bbox + with gr.Row(variant='compact'): + ref_image = gr.Image(label='Ref image (for conviently locate regions)', image_mode=None, elem_id=f'MD-bbox-ref-{tab}', interactive=True) + if not is_img2img: + # gradio has a serious bug: it cannot accept multiple inputs when you use both js and fn. + # to workaround this, we concat the inputs into a single string and parse it in js + def create_t2i_ref(string): + w, h = [int(x) for x in string.split('x')] + w = max(w, opt_f) + h = max(h, opt_f) + return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 + create_button.click( + fn=create_t2i_ref, + inputs=overwrite_size, + outputs=ref_image, + _js='onCreateT2IRefClick', + show_progress=False) + else: + create_button.click(fn=None, outputs=ref_image, _js='onCreateI2IRefClick', show_progress=False) + + with gr.Row(variant='compact'): + cfg_name = gr.Textbox(label='Custom Config File', value='config.json', elem_id=uid('cfg-name')) + cfg_dump = gr.Button(value='💾 Save', variant='tool') + cfg_load = gr.Button(value='⚙️ Load', variant='tool') + + with gr.Row(variant='compact'): + cfg_tip = gr.HTML(value='', visible=False) + + for i in range(BBOX_MAX_NUM): + # Only when displaying & png generate info we use index i+1, in other cases we use i + with gr.Accordion(f'Region {i+1}', open=False, elem_id=f'MD-accordion-{tab}-{i}'): + with gr.Row(variant='compact'): + e = gr.Checkbox(label=f'Enable Region {i+1}', value=False, elem_id=f'MD-bbox-{tab}-{i}-enable') + e.change(fn=None, inputs=e, outputs=e, _js=f'e => onBoxEnableClick({is_t2i}, {i}, e)', show_progress=False) + + blend_mode = gr.Dropdown(label='Type', choices=[e.value for e in BlendMode], value=BlendMode.BACKGROUND.value, elem_id=f'MD-{tab}-{i}-blend-mode') + feather_ratio = gr.Slider(label='Feather', value=0.2, minimum=0, maximum=1, step=0.05, visible=False, elem_id=f'MD-{tab}-{i}-feather') + + blend_mode.change(fn=lambda x: gr_show(x==BlendMode.FOREGROUND.value), inputs=blend_mode, outputs=feather_ratio, show_progress=False) + + with gr.Row(variant='compact'): + x = gr.Slider(label='x', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-x') + y = gr.Slider(label='y', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, 
elem_id=f'MD-{tab}-{i}-y') + + with gr.Row(variant='compact'): + w = gr.Slider(label='w', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-w') + h = gr.Slider(label='h', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-h') + + x.change(fn=None, inputs=x, outputs=x, _js=f'v => onBoxChange({is_t2i}, {i}, "x", v)', show_progress=False) + y.change(fn=None, inputs=y, outputs=y, _js=f'v => onBoxChange({is_t2i}, {i}, "y", v)', show_progress=False) + w.change(fn=None, inputs=w, outputs=w, _js=f'v => onBoxChange({is_t2i}, {i}, "w", v)', show_progress=False) + h.change(fn=None, inputs=h, outputs=h, _js=f'v => onBoxChange({is_t2i}, {i}, "h", v)', show_progress=False) + + prompt = gr.Text(show_label=False, placeholder=f'Prompt, will append to your {tab} prompt', max_lines=2, elem_id=f'MD-{tab}-{i}-prompt') + neg_prompt = gr.Text(show_label=False, placeholder='Negative Prompt, will also be appended', max_lines=1, elem_id=f'MD-{tab}-{i}-neg-prompt') + with gr.Row(variant='compact'): + seed = gr.Number(label='Seed', value=-1, visible=True, elem_id=f'MD-{tab}-{i}-seed') + random_seed = gr.Button(value='🎲', variant='tool', elem_id=f'MD-{tab}-{i}-random_seed') + reuse_seed = gr.Button(value='♻️', variant='tool', elem_id=f'MD-{tab}-{i}-reuse_seed') + random_seed.click(fn=lambda: -1, outputs=seed, show_progress=False) + reuse_seed.click(fn=None, inputs=seed, outputs=seed, _js=f'e => getSeedInfo({is_t2i}, {i+1}, e)', show_progress=False) + + control = [e, x, y, w, h, prompt, neg_prompt, blend_mode, feather_ratio, seed] + assert len(control) == NUM_BBOX_PARAMS + bbox_controls.extend(control) + + # NOTE: dynamically hard coded!! + load_regions_js = ''' + function onBoxChangeAll(ref_image, cfg_name, ...args) { + const is_t2i = %s; + const n_bbox = %d; + const n_ctrl = %d; + for (let i=0; iPlease test on small images before actual upscale. Default params require denoise <= 0.6
') + with gr.Row(variant='compact'): + noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) + noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) + noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) + + # The control includes txt2img and img2img, we use t2i and i2i to distinguish them + + return [ + enabled, method, + keep_input_size, + window_size, overlap, batch_size, + scale_factor, + noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel, + control_tensor_cpu, + random_jitter, + c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode + ] + + + def process(self, p: Processing, + enabled: bool, method: str, + keep_input_size: bool, + window_size:int, overlap: int, tile_batch_size: int, + scale_factor: float, + noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, + control_tensor_cpu: bool, + random_jitter:bool, + c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode + ): + + # unhijack & unhook, in case it broke at last time + self.reset() + p.mixture = mixture_mode + if not mixture_mode: + sigma = sigma/2 + if not enabled: return + + ''' upscale ''' + # store canvas size settings + if hasattr(p, "init_images"): + p.init_images_original_md = [img.copy() for img in p.init_images] + p.width_original_md = p.width + p.height_original_md = p.height + p.current_scale_num = 1 + p.gaussian_filter = gaussian_filter + p.scale_factor = int(scale_factor) + + is_img2img = hasattr(p, "init_images") and len(p.init_images) > 0 + if is_img2img: + init_img = p.init_images[0] + init_img = images.flatten(init_img, opts.img2img_background_color) + image = init_img + if keep_input_size: + p.width = image.width + p.height = image.height + p.width_original_md = p.width + p.height_original_md = p.height + else: #XXX:To adapt to noise inversion, we do not multiply the scale factor here + p.width = p.width_original_md + p.height = p.height_original_md + else: # txt2img + p.width = p.width_original_md + p.height = p.height_original_md + + if 'png info': + info = {} + p.extra_generation_params["Tiled Diffusion"] = info + + info['Method'] = method + info['Window Size'] = window_size + info['Tile Overlap'] = overlap + info['Tile batch size'] = tile_batch_size + info["Global batch size"] = batch_size_g + + if is_img2img: + info['Upscale factor'] = scale_factor + if keep_input_size: + info['Keep input size'] = keep_input_size + if noise_inverse: + info['NoiseInv'] = noise_inverse + info['NoiseInv Steps'] = noise_inverse_steps + info['NoiseInv Retouch'] = noise_inverse_retouch + info['NoiseInv Renoise strength'] = noise_inverse_renoise_strength + info['NoiseInv Kernel size'] = noise_inverse_renoise_kernel + + ''' ControlNet hackin ''' + try: + from scripts.cldm import ControlNet + + for script in p.scripts.scripts + p.scripts.alwayson_scripts: + if hasattr(script, "latest_network") and script.title().lower() == "controlnet": + self.controlnet_script = script + print("[Demo Fusion] ControlNet found, support is enabled.") + break + except ImportError: + pass + + ''' StableSR hackin ''' + for script in p.scripts.scripts: + if hasattr(script, 
"stablesr_model") and script.title().lower() == "stablesr": + if script.stablesr_model is not None: + self.stablesr_script = script + print("[Demo Fusion] StableSR found, support is enabled.") + break + + ''' hijack inner APIs, see unhijack in reset() ''' + Script.create_sampler_original_md = sd_samplers.create_sampler + + sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack( + name, model, p, Method_2(method), control_tensor_cpu,window_size, noise_inverse, noise_inverse_steps, noise_inverse_retouch, + noise_inverse_renoise_strength, noise_inverse_renoise_kernel, overlap, tile_batch_size,random_jitter,batch_size_g + ) + + + p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts: self.sample_hijack( + conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img, + window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g) + + processing.create_infotext_ori = processing.create_infotext + + p.width_list = [p.height] + p.height_list = [p.height] + + processing.create_infotext = create_infotext_hijack + ## end + + + def postprocess_batch(self, p: Processing, enabled, *args, **kwargs): + if not enabled: return + + if self.delegate is not None: self.delegate.reset_controlnet_tensors() + + def postprocess_batch_list(self, p, pp, enabled, *args, **kwargs): + if not enabled: return + for idx,image in enumerate(pp.images): + idx_b = idx//p.batch_size + pp.images[idx] = image[:,:image.shape[1]//(p.scale_factor)*(idx_b+1),:image.shape[2]//(p.scale_factor)*(idx_b+1)] + p.seeds = [item for _ in range(p.scale_factor) for item in p.seeds] + p.prompts = [item for _ in range(p.scale_factor) for item in p.prompts] + p.all_negative_prompts = [item for _ in range(p.scale_factor) for item in p.all_negative_prompts] + p.negative_prompts = [item for _ in range(p.scale_factor) for item in p.negative_prompts] + if p.color_corrections != None: + p.color_corrections = [item for _ in range(p.scale_factor) for item in p.color_corrections] + p.width_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.width for _ in range(p.batch_size)]] + p.height_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.height for _ in range(p.batch_size)]] + return + + def postprocess(self, p: Processing, processed, enabled, *args): + if not enabled: return + # unhijack & unhook + self.reset() + + # restore canvas size settings + if hasattr(p, 'init_images') and hasattr(p, 'init_images_original_md'): + p.init_images.clear() # NOTE: do NOT change the list object, compatible with shallow copy of XYZ-plot + p.init_images.extend(p.init_images_original_md) + del p.init_images_original_md + p.width = p.width_original_md ; del p.width_original_md + p.height = p.height_original_md ; del p.height_original_md + + # clean up noise inverse latent for folder-based processing + if hasattr(p, 'noise_inverse_latent'): + del p.noise_inverse_latent + + ''' ↓↓↓ inner API hijack ↓↓↓ ''' + @torch.no_grad() + def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g): + ################################################## Phase Initialization ###################################################### + + if not image_ori: + p.current_step = 0 + p.denoising_strength = strength + # p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) 
#NOTE:Wrong but very useful. If corrected, please replace with the content with the following lines + # latents = p.rng.next() + + p.sampler = Script.create_sampler_original_md(p.sampler_name, p.sd_model) #scale + x = p.rng.next() + print("### Phase 1 Denoising ###") + latents = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.txt2img_image_conditioning(x)) + latents_ = F.pad(latents, (0, latents.shape[3]*(p.scale_factor-1), 0, latents.shape[2]*(p.scale_factor-1))) + res = latents_ + del x + p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) + starting_scale = 2 + else: # img2img + print("### Encoding Real Image ###") + latents = p.init_latent + starting_scale = 1 + + + anchor_mean = latents.mean() + anchor_std = latents.std() + + devices.torch_gc() + + ####################################################### Phase Upscaling ##################################################### + p.cosine_scale_1 = c1 + p.cosine_scale_2 = c2 + p.cosine_scale_3 = c3 + self.delegate.sig = sigma + p.latents = latents + for current_scale_num in range(starting_scale, p.scale_factor+1): + p.current_scale_num = current_scale_num + print("### Phase {} Denoising ###".format(current_scale_num)) + p.current_height = p.height_original_md * current_scale_num + p.current_width = p.width_original_md * current_scale_num + + + p.latents = F.interpolate(p.latents, size=(int(p.current_height / opt_f), int(p.current_width / opt_f)), mode='bicubic') + p.rng = rng.ImageRNG(p.latents.shape[1:], p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + + + self.delegate.w = int(p.current_width / opt_f) + self.delegate.h = int(p.current_height / opt_f) + self.delegate.get_views(overlap, tile_batch_size,batch_size_g) + + info = ', '.join([ + # f"{method.value} hooked into {name!r} sampler", + f"Tile size: {self.delegate.window_size}", + f"Tile count: {self.delegate.num_tiles}", + f"Batch size: {self.delegate.tile_bs}", + f"Tile batches: {len(self.delegate.batched_bboxes)}", + f"Global batch size: {self.delegate.global_tile_bs}", + f"Global batches: {len(self.delegate.global_batched_bboxes)}", + ]) + + print(info) + + noise = p.rng.next() + if hasattr(p,'initial_noise_multiplier'): + if p.initial_noise_multiplier != 1.0: + p.extra_generation_params["Noise multiplier"] = p.initial_noise_multiplier + noise *= p.initial_noise_multiplier + else: + p.image_conditioning = p.txt2img_image_conditioning(noise) + + p.noise = noise + p.x = p.latents.clone() + p.current_step=0 + + p.latents = p.sampler.sample_img2img(p,p.latents, noise , conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) + if self.flag_noise_inverse: + self.delegate.sampler_raw.sample_img2img = self.delegate.sample_img2img_original + self.flag_noise_inverse = False + + p.latents = (p.latents - p.latents.mean()) / p.latents.std() * anchor_std + anchor_mean + latents_ = F.pad(p.latents, (0, p.latents.shape[3]//current_scale_num*(p.scale_factor-current_scale_num), 0, p.latents.shape[2]//current_scale_num*(p.scale_factor-current_scale_num))) + if current_scale_num==1: + res = latents_ + else: + res = torch.concatenate((res,latents_),axis=0) + + ######################################################################################################################################### + + return res + + @staticmethod + def callback_hijack(self_sampler,d,p): + p.current_step = d['i'] + + if self_sampler.stop_at is 
not None and p.current_step > self_sampler.stop_at: + raise InterruptedException + + state.sampling_step = p.current_step + shared.total_tqdm.update() + p.current_step += 1 + + + def create_sampler_hijack( + self, name: str, model: LatentDiffusion, p: Processing, method: Method_2, control_tensor_cpu:bool,window_size, noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch:float, + noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, overlap:int, tile_batch_size:int, random_jitter:bool,batch_size_g:int + ): + if self.delegate is not None: + # samplers are stateless, we reuse it if possible + if self.delegate.sampler_name == name: + # before we reuse the sampler, we refresh the control tensor + # so that we are compatible with ControlNet batch processing + if self.controlnet_script: + self.delegate.prepare_controlnet_tensors(refresh=True) + return self.delegate.sampler_raw + else: + self.reset() + sd_samplers_common.Sampler.callback_ori = sd_samplers_common.Sampler.callback_state + sd_samplers_common.Sampler.callback_state = lambda self_sampler,d:Script.callback_hijack(self_sampler,d,p) + + self.flag_noise_inverse = hasattr(p, "init_images") and len(p.init_images) > 0 and noise_inverse + flag_noise_inverse = self.flag_noise_inverse + if flag_noise_inverse: + print('warn: noise inversion only supports the "Euler" sampler, switching to it silently...') + name = 'Euler' + p.sampler_name = 'Euler' + if name is None: print('>> name is empty') + if model is None: print('>> model is empty') + sampler = Script.create_sampler_original_md(name, model) + if method ==Method_2.DEMO_FU: delegate_cls = DemoFusion + else: raise NotImplementedError(f"Method {method} not implemented.") + + delegate = delegate_cls(p, sampler) + delegate.window_size = min(min(window_size,p.width//8),p.height//8) + p.random_jitter = random_jitter + + if flag_noise_inverse: + get_cache_callback = self.noise_inverse_get_cache + set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch) + delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) + + # delegate.get_views(overlap,tile_batch_size,batch_size_g) + if self.controlnet_script: + delegate.init_controlnet(self.controlnet_script, control_tensor_cpu) + if self.stablesr_script: + delegate.init_stablesr(self.stablesr_script) + + # init everything done, perform sanity check & pre-computations + # hijack the behaviours + delegate.hook() + + self.delegate = delegate + + exts = [ + "ControlNet" if self.controlnet_script else None, + "StableSR" if self.stablesr_script else None, + ] + ext_info = ', '.join([e for e in exts if e]) + if ext_info: ext_info = f' (ext: {ext_info})' + print(ext_info) + + return delegate.sampler_raw + + def create_random_tensors_hijack( + self, bbox_settings: Dict, region_info: Dict, + shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None, + ): + org_random_tensors = Script.create_random_tensors_original_md(shape, seeds, subseeds, subseed_strength, seed_resize_from_h, seed_resize_from_w, p) + height, width = shape[1], shape[2] + background_noise = torch.zeros_like(org_random_tensors) + background_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) + foreground_noise = torch.zeros_like(org_random_tensors) + foreground_noise_count = torch.zeros((1, 1, height,
width), device=org_random_tensors.device) + + for i, v in bbox_settings.items(): + seed = get_fixed_seed(v.seed) + x, y, w, h = v.x, v.y, v.w, v.h + # convert to pixel + x = int(x * width) + y = int(y * height) + w = math.ceil(w * width) + h = math.ceil(h * height) + # clamp + x = max(0, x) + y = max(0, y) + w = min(width - x, w) + h = min(height - y, h) + # create random tensor + torch.manual_seed(seed) + rand_tensor = torch.randn((1, org_random_tensors.shape[1], h, w), device=devices.cpu) + if BlendMode(v.blend_mode) == BlendMode.BACKGROUND: + background_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(background_noise.device) + background_noise_count[:, :, y:y+h, x:x+w] += 1 + elif BlendMode(v.blend_mode) == BlendMode.FOREGROUND: + foreground_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(foreground_noise.device) + foreground_noise_count[:, :, y:y+h, x:x+w] += 1 + else: + raise NotImplementedError + region_info['Region ' + str(i+1)]['seed'] = seed + + # average + background_noise = torch.where(background_noise_count > 1, background_noise / background_noise_count, background_noise) + foreground_noise = torch.where(foreground_noise_count > 1, foreground_noise / foreground_noise_count, foreground_noise) + # paste two layers to original random tensor + org_random_tensors = torch.where(background_noise_count > 0, background_noise, org_random_tensors) + org_random_tensors = torch.where(foreground_noise_count > 0, foreground_noise, org_random_tensors) + return org_random_tensors + + ''' ↓↓↓ helper methods ↓↓↓ ''' + + def dump_regions(self, cfg_name, *bbox_controls): + if not cfg_name: return gr_value(f'Config file name cannot be empty.', visible=True) + + bbox_settings = build_bbox_settings(bbox_controls) + data = {'bbox_controls': [v._asdict() for v in bbox_settings.values()]} + + if not os.path.exists(CFG_PATH): os.makedirs(CFG_PATH) + fp = os.path.join(CFG_PATH, cfg_name) + with open(fp, 'w', encoding='utf-8') as fh: + json.dump(data, fh, indent=2, ensure_ascii=False) + + return gr_value(f'Config saved to {fp}.', visible=True) + + def load_regions(self, ref_image, cfg_name, *bbox_controls): + if ref_image is None: + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Please create or upload a ref image first.', visible=True)] + fp = os.path.join(CFG_PATH, cfg_name) + if not os.path.exists(fp): + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Config {fp} not found.', visible=True)] + + try: + with open(fp, 'r', encoding='utf-8') as fh: + data = json.load(fh) + except Exception as e: + return [gr_value(v) for v in bbox_controls] + [gr_value(f'Failed to load config {fp}: {e}', visible=True)] + + num_boxes = len(data['bbox_controls']) + data_list = [] + for i in range(BBOX_MAX_NUM): + if i < num_boxes: + for k in BBoxSettings._fields: + if k in data['bbox_controls'][i]: + data_list.append(data['bbox_controls'][i][k]) + else: + data_list.append(None) + else: + data_list.extend(DEFAULT_BBOX_SETTINGS) + + return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)] + + + def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float): + self.noise_inverse_cache = NoiseInverseCache(p.sd_model.sd_model_hash, x0, xt, steps, retouch, prompts) + + def noise_inverse_get_cache(self): + return self.noise_inverse_cache + + + def reset(self): + ''' unhijack inner APIs, see hijack in process() ''' + if hasattr(Script, "create_sampler_original_md"): + sd_samplers.create_sampler = 
Script.create_sampler_original_md + del Script.create_sampler_original_md + if hasattr(Script, "create_random_tensors_original_md"): + processing.create_random_tensors = Script.create_random_tensors_original_md + del Script.create_random_tensors_original_md + if hasattr(sd_samplers_common.Sampler, "callback_ori"): + sd_samplers_common.Sampler.callback_state = sd_samplers_common.Sampler.callback_ori + del sd_samplers_common.Sampler.callback_ori + if hasattr(processing, "create_infotext_ori"): + processing.create_infotext = processing.create_infotext_ori + del processing.create_infotext_ori + DemoFusion.unhook() + self.delegate = None + + def reset_and_gc(self): + self.reset() + self.noise_inverse_cache = None + + import gc; gc.collect() + devices.torch_gc() + + try: + import os + import psutil + mem = psutil.Process(os.getpid()).memory_info() + print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB') + from modules.shared import mem_mon as vram_mon + from modules.memmon import MemUsageMonitor + vram_mon: MemUsageMonitor + free, total = vram_mon.cuda_mem_get_info() + print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB') + except: + pass diff --git a/extensions/multidiffusion-upscaler-for-automatic1111/scripts/tilevae.py b/extensions/multidiffusion-upscaler-for-automatic1111/scripts/tilevae.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4ed4a2e3b0e0213911cb8d289fb2d36d974d9c --- /dev/null +++ b/extensions/multidiffusion-upscaler-for-automatic1111/scripts/tilevae.py @@ -0,0 +1,758 @@ +''' +# ------------------------------------------------------------------------ +# +# Tiled VAE +# +# Introducing a revolutionary new optimization designed to make +# the VAE work with giant images on limited VRAM! +# Say goodbye to the frustration of OOM and hello to seamless output! +# +# ------------------------------------------------------------------------ +# +# This script is a wild hack that splits the image into tiles, +# encodes each tile separately, and merges the result back together. +# +# Advantages: +# - The VAE can now work with giant images on limited VRAM +# (~10 GB for 8K images!) +# - The merged output is completely seamless without any post-processing. +# +# Drawbacks: +# - NaNs always appear in for 8k images when you use fp16 (half) VAE +# You must use --no-half-vae to disable half VAE for that giant image. +# - The gradient calculation is not compatible with this hack. It +# will break any backward() or torch.autograd.grad() that passes VAE. +# (But you can still use the VAE to generate training data.) +# +# How it works: +# 1. The image is split into tiles, which are then padded with 11/32 pixels' in the decoder/encoder. +# 2. When Fast Mode is disabled: +# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile. +# 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile. +# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues. +# 4. A zigzag execution order is used to reduce unnecessary data transfer. +# 3. When Fast Mode is enabled: +# 1. The original input is downsampled and passed to a separate task queue. +# 2. Its group norm parameters are recorded and used by all tiles' task queues. +# 3. Each tile is separately processed without any RAM-VRAM data transfer. +# 4. After all tiles are processed, tiles are written to a result buffer and returned. 
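The statistic pooling described in steps 2.2 and 2.3 of the overview above can be sketched in isolation. The following is a minimal, self-contained example, not the extension's `GroupNormParam` class that appears later in this file; like the script, it pools the per-tile variances as a pixel-count-weighted average rather than computing an exact pooled variance, and the GroupNorm layer's affine weight/bias are omitted for brevity.

```python
import torch

def pooled_group_norm(tiles, num_groups=32, eps=1e-6):
    """Normalize a list of [B, C, h, w] tiles with GroupNorm statistics pooled over all tiles."""
    stats = []
    for t in tiles:
        b, c = t.shape[:2]                                             # C must be divisible by num_groups
        g = t.reshape(1, b * num_groups, c // num_groups, -1)          # grouped view of the tile
        var, mean = torch.var_mean(g, dim=[0, 2, 3], unbiased=False)   # per-(batch*group) statistics
        stats.append((var, mean, t.shape[2] * t.shape[3]))             # weight by the tile's pixel count
    total = sum(px for _, _, px in stats)
    var  = sum(v * (px / total) for v, _, px in stats)                 # pixel-weighted pooling
    mean = sum(m * (px / total) for _, m, px in stats)
    outs = []
    for t in tiles:
        b, c = t.shape[:2]
        g = t.reshape(1, b * num_groups, c // num_groups, -1)
        g = (g - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + eps)
        outs.append(g.reshape_as(t))                                   # back to [B, C, h, w]
    return outs

# two tiles of the same feature map, normalized with shared statistics so their border stays consistent
tiles = [torch.randn(1, 64, 32, 48), torch.randn(1, 64, 32, 16)]
normed = pooled_group_norm(tiles)
```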
+# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode. +# +# Enjoy! +# +# @Author: LI YI @ Nanyang Technological University - Singapore +# @Date: 2023-03-02 +# @License: CC BY-NC-SA 4.0 +# +# Please give me a star if you like this project! +# +# ------------------------------------------------------------------------- +''' + +import gc +import math +from time import time +from tqdm import tqdm + +import torch +import torch.version +import torch.nn.functional as F +import gradio as gr + +import modules.scripts as scripts +import modules.devices as devices +from modules.shared import state, opts +from modules.ui import gr_show +from modules.processing import opt_f +from modules.sd_vae_approx import cheap_approximation +from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock + +from tile_utils.attn import get_attn_func +from tile_utils.typing import Processing + +if hasattr(opts, 'hypertile_enable_unet'): # webui >= 1.7 + from modules.ui_components import InputAccordion +else: + InputAccordion = None + + +def get_rcmd_enc_tsize(): + if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: + total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 + if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 + elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 + elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 + else: ENCODER_TILE_SIZE = 960 + else: ENCODER_TILE_SIZE = 512 + return ENCODER_TILE_SIZE + + +def get_rcmd_dec_tsize(): + if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: + total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 + if total_memory > 30*1000: DECODER_TILE_SIZE = 256 + elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 + elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 + elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 + else: DECODER_TILE_SIZE = 64 + else: DECODER_TILE_SIZE = 64 + return DECODER_TILE_SIZE + + +def inplace_nonlinearity(x): + # Test: fix for Nans + return F.silu(x, inplace=True) + + +def attn2task(task_queue, net): + attn_forward = get_attn_func() + task_queue.append(('store_res', lambda x: x)) + task_queue.append(('pre_norm', net.norm)) + task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) + task_queue.append(['add_res', None]) + + +def resblock2task(queue, block): + """ + Turn a ResNetBlock into a sequence of tasks and append to the task queue + + @param queue: the target task queue + @param block: ResNetBlock + + """ + if block.in_channels != block.out_channels: + if block.use_conv_shortcut: + queue.append(('store_res', block.conv_shortcut)) + else: + queue.append(('store_res', block.nin_shortcut)) + else: + queue.append(('store_res', lambda x: x)) + queue.append(('pre_norm', block.norm1)) + queue.append(('silu', inplace_nonlinearity)) + queue.append(('conv1', block.conv1)) + queue.append(('pre_norm', block.norm2)) + queue.append(('silu', inplace_nonlinearity)) + queue.append(('conv2', block.conv2)) + queue.append(['add_res', None]) + + +def build_sampling(task_queue, net, is_decoder): + """ + Build the sampling part of a task queue + @param task_queue: the target task queue + @param net: the network + @param is_decoder: currently building decoder or encoder + """ + if is_decoder: + resblock2task(task_queue, net.mid.block_1) + attn2task(task_queue, net.mid.attn_1) + resblock2task(task_queue, net.mid.block_2) + resolution_iter = 
reversed(range(net.num_resolutions)) + block_ids = net.num_res_blocks + 1 + condition = 0 + module = net.up + func_name = 'upsample' + else: + resolution_iter = range(net.num_resolutions) + block_ids = net.num_res_blocks + condition = net.num_resolutions - 1 + module = net.down + func_name = 'downsample' + + for i_level in resolution_iter: + for i_block in range(block_ids): + resblock2task(task_queue, module[i_level].block[i_block]) + if i_level != condition: + task_queue.append((func_name, getattr(module[i_level], func_name))) + + if not is_decoder: + resblock2task(task_queue, net.mid.block_1) + attn2task(task_queue, net.mid.attn_1) + resblock2task(task_queue, net.mid.block_2) + + +def build_task_queue(net, is_decoder): + """ + Build a single task queue for the encoder or decoder + @param net: the VAE decoder or encoder network + @param is_decoder: currently building decoder or encoder + @return: the task queue + """ + task_queue = [] + task_queue.append(('conv_in', net.conv_in)) + + # construct the sampling part of the task queue + # because encoder and decoder share the same architecture, we extract the sampling part + build_sampling(task_queue, net, is_decoder) + + if not is_decoder or not net.give_pre_end: + task_queue.append(('pre_norm', net.norm_out)) + task_queue.append(('silu', inplace_nonlinearity)) + task_queue.append(('conv_out', net.conv_out)) + if is_decoder and net.tanh_out: + task_queue.append(('tanh', torch.tanh)) + + return task_queue + + +def clone_task_queue(task_queue): + """ + Clone a task queue + @param task_queue: the task queue to be cloned + @return: the cloned task queue + """ + return [[item for item in task] for task in task_queue] + + +def get_var_mean(input, num_groups, eps=1e-6): + """ + Get mean and var for group norm + """ + b, c = input.size(0), input.size(1) + channel_in_group = int(c/num_groups) + input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:]) + var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False) + return var, mean + + +def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): + """ + Custom group norm with fixed mean and var + + @param input: input tensor + @param num_groups: number of groups. 
by default, num_groups = 32 + @param mean: mean, must be pre-calculated by get_var_mean + @param var: var, must be pre-calculated by get_var_mean + @param weight: weight, should be fetched from the original group norm + @param bias: bias, should be fetched from the original group norm + @param eps: epsilon, by default, eps = 1e-6 to match the original group norm + + @return: normalized tensor + """ + b, c = input.size(0), input.size(1) + channel_in_group = int(c/num_groups) + input_reshaped = input.contiguous().view( + 1, int(b * num_groups), channel_in_group, *input.size()[2:]) + + out = F.batch_norm(input_reshaped, mean.to(input), var.to(input), weight=None, bias=None, training=False, momentum=0, eps=eps) + out = out.view(b, c, *input.size()[2:]) + + # post affine transform + if weight is not None: + out *= weight.view(1, -1, 1, 1) + if bias is not None: + out += bias.view(1, -1, 1, 1) + return out + + +def crop_valid_region(x, input_bbox, target_bbox, is_decoder): + """ + Crop the valid region from the tile + @param x: input tile + @param input_bbox: original input bounding box + @param target_bbox: output bounding box + @param scale: scale factor + @return: cropped tile + """ + padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox] + margin = [target_bbox[i] - padded_bbox[i] for i in range(4)] + return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]] + + +# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓ + +def perfcount(fn): + def wrapper(*args, **kwargs): + ts = time() + + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(devices.device) + devices.torch_gc() + gc.collect() + + ret = fn(*args, **kwargs) + + devices.torch_gc() + gc.collect() + if torch.cuda.is_available(): + vram = torch.cuda.max_memory_allocated(devices.device) / 2**20 + print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB') + else: + print(f'[Tiled VAE]: Done in {time() - ts:.3f}s') + + return ret + return wrapper + +# ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑ + + +class GroupNormParam: + + def __init__(self): + self.var_list = [] + self.mean_list = [] + self.pixel_list = [] + self.weight = None + self.bias = None + + def add_tile(self, tile, layer): + var, mean = get_var_mean(tile, 32) + # For giant images, the variance can be larger than max float16 + # In this case we create a copy to float32 + if var.dtype == torch.float16 and var.isinf().any(): + fp32_tile = tile.float() + var, mean = get_var_mean(fp32_tile, 32) + # ============= DEBUG: test for infinite ============= + # if torch.isinf(var).any(): + # print('var: ', var) + # ==================================================== + self.var_list.append(var) + self.mean_list.append(mean) + self.pixel_list.append( + tile.shape[2]*tile.shape[3]) + if hasattr(layer, 'weight'): + self.weight = layer.weight + self.bias = layer.bias + else: + self.weight = None + self.bias = None + + def summary(self): + """ + summarize the mean and var and return a function + that apply group norm on each tile + """ + if len(self.var_list) == 0: return None + + var = torch.vstack(self.var_list) + mean = torch.vstack(self.mean_list) + max_value = max(self.pixel_list) + pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value + sum_pixels = torch.sum(pixels) + pixels = pixels.unsqueeze(1) / sum_pixels + var = torch.sum(var * pixels, dim=0) + mean = torch.sum(mean * pixels, dim=0) + return lambda x: custom_group_norm(x, 32, 
mean, var, self.weight, self.bias) + + @staticmethod + def from_tile(tile, norm): + """ + create a function from a single tile without summary + """ + var, mean = get_var_mean(tile, 32) + if var.dtype == torch.float16 and var.isinf().any(): + fp32_tile = tile.float() + var, mean = get_var_mean(fp32_tile, 32) + # if it is a macbook, we need to convert back to float16 + if var.device.type == 'mps': + # clamp to avoid overflow + var = torch.clamp(var, 0, 60000) + var = var.half() + mean = mean.half() + if hasattr(norm, 'weight'): + weight = norm.weight + bias = norm.bias + else: + weight = None + bias = None + + def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): + return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) + return group_norm_func + + +class VAEHook: + + def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False): + self.net = net # encoder | decoder + self.tile_size = tile_size + self.is_decoder = is_decoder + self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder) + self.color_fix = color_fix and not is_decoder + self.to_gpu = to_gpu + self.pad = 11 if is_decoder else 32 # FIXME: magic number + + def __call__(self, x): + original_device = next(self.net.parameters()).device + try: + if self.to_gpu: + self.net = self.net.to(devices.get_optimal_device()) + + B, C, H, W = x.shape + if max(H, W) <= self.pad * 2 + self.tile_size: + print("[Tiled VAE]: the input size is tiny and unnecessary to tile.") + return self.net.original_forward(x) + else: + return self.vae_tile_forward(x) + finally: + self.net = self.net.to(original_device) + + def get_best_tile_size(self, lowerbound, upperbound): + """ + Get the best tile size for GPU memory + """ + divider = 32 + while divider >= 2: + remainer = lowerbound % divider + if remainer == 0: + return lowerbound + candidate = lowerbound - remainer + divider + if candidate <= upperbound: + return candidate + divider //= 2 + return lowerbound + + def split_tiles(self, h, w): + """ + Tool function to split the image into tiles + @param h: height of the image + @param w: width of the image + @return: tile_input_bboxes, tile_output_bboxes + """ + tile_input_bboxes, tile_output_bboxes = [], [] + tile_size = self.tile_size + pad = self.pad + num_height_tiles = math.ceil((h - 2 * pad) / tile_size) + num_width_tiles = math.ceil((w - 2 * pad) / tile_size) + # If any of the numbers are 0, we let it be 1 + # This is to deal with long and thin images + num_height_tiles = max(num_height_tiles, 1) + num_width_tiles = max(num_width_tiles, 1) + + # Suggestions from https://github.com/Kahsolt: auto shrink the tile size + real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles) + real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles) + real_tile_height = self.get_best_tile_size(real_tile_height, tile_size) + real_tile_width = self.get_best_tile_size(real_tile_width, tile_size) + + print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' + + f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}') + + for i in range(num_height_tiles): + for j in range(num_width_tiles): + # bbox: [x1, x2, y1, y2] + # the padding is is unnessary for image borders. 
So we directly start from (32, 32) + input_bbox = [ + pad + j * real_tile_width, + min(pad + (j + 1) * real_tile_width, w), + pad + i * real_tile_height, + min(pad + (i + 1) * real_tile_height, h), + ] + + # if the output bbox is close to the image boundary, we extend it to the image boundary + output_bbox = [ + input_bbox[0] if input_bbox[0] > pad else 0, + input_bbox[1] if input_bbox[1] < w - pad else w, + input_bbox[2] if input_bbox[2] > pad else 0, + input_bbox[3] if input_bbox[3] < h - pad else h, + ] + + # scale to get the final output bbox + output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox] + tile_output_bboxes.append(output_bbox) + + # indistinguishable expand the input bbox by pad pixels + tile_input_bboxes.append([ + max(0, input_bbox[0] - pad), + min(w, input_bbox[1] + pad), + max(0, input_bbox[2] - pad), + min(h, input_bbox[3] + pad), + ]) + + return tile_input_bboxes, tile_output_bboxes + + @torch.no_grad() + def estimate_group_norm(self, z, task_queue, color_fix): + device = z.device + tile = z + last_id = len(task_queue) - 1 + while last_id >= 0 and task_queue[last_id][0] != 'pre_norm': + last_id -= 1 + if last_id <= 0 or task_queue[last_id][0] != 'pre_norm': + raise ValueError('No group norm found in the task queue') + # estimate until the last group norm + for i in range(last_id + 1): + task = task_queue[i] + if task[0] == 'pre_norm': + group_norm_func = GroupNormParam.from_tile(tile, task[1]) + task_queue[i] = ('apply_norm', group_norm_func) + if i == last_id: + return True + tile = group_norm_func(tile) + elif task[0] == 'store_res': + task_id = i + 1 + while task_id < last_id and task_queue[task_id][0] != 'add_res': + task_id += 1 + if task_id >= last_id: + continue + task_queue[task_id][1] = task[1](tile) + elif task[0] == 'add_res': + tile += task[1].to(device) + task[1] = None + elif color_fix and task[0] == 'downsample': + for j in range(i, last_id + 1): + if task_queue[j][0] == 'store_res': + task_queue[j] = ('store_res_cpu', task_queue[j][1]) + return True + else: + tile = task[1](tile) + try: + devices.test_for_nans(tile, "vae") + except: + print(f'Nan detected in fast mode estimation. Fast mode disabled.') + return False + + raise IndexError('Should not reach here') + + @perfcount + @torch.no_grad() + def vae_tile_forward(self, z): + """ + Decode a latent vector z into an image in a tiled manner. 
+ @param z: latent vector + @return: image + """ + device = next(self.net.parameters()).device + dtype = next(self.net.parameters()).dtype + net = self.net + tile_size = self.tile_size + is_decoder = self.is_decoder + + z = z.detach() # detach the input to avoid backprop + + N, height, width = z.shape[0], z.shape[2], z.shape[3] + net.last_z_shape = z.shape + + # Split the input into tiles and build a task queue for each tile + print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}') + + in_bboxes, out_bboxes = self.split_tiles(height, width) + + # Prepare tiles by splitting the input latents + tiles = [] + for input_bbox in in_bboxes: + tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu() + tiles.append(tile) + + num_tiles = len(tiles) + num_completed = 0 + + # Build task queues + single_task_queue = build_task_queue(net, is_decoder) + if self.fast_mode: + # Fast mode: downsample the input image to the tile size, + # then estimate the group norm parameters on the downsampled image + scale_factor = tile_size / max(height, width) + z = z.to(device) + downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact') + # use nearest-exact to keep statistics as close as possible + print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image') + + # ======= Special thanks to @Kahsolt for distribution shift issue ======= # + # The downsampling will heavily distort the latent's mean and std, so we need to recover it. + std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True) + std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True) + downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old + del std_old, mean_old, std_new, mean_new + # occasionally the std_new is too small or too large, which exceeds the range of float16 + # so we need to clamp it to max z's range.
+ downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) + estimate_task_queue = clone_task_queue(single_task_queue) + if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): + single_task_queue = estimate_task_queue + del downsampled_z + + task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] + + # Dummy result + result = None + result_approx = None + try: + with devices.autocast(): + result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() + except: pass + # Free memory of input latent tensor + del z + + # Task queue execution + pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") + + # execute the task back and forth when switch tiles so that we always + # keep one tile on the GPU to reduce unnecessary data transfer + forward = True + interrupted = False + #state.interrupted = interrupted + while True: + if state.interrupted: interrupted = True ; break + + group_norm_param = GroupNormParam() + for i in range(num_tiles) if forward else reversed(range(num_tiles)): + if state.interrupted: interrupted = True ; break + + tile = tiles[i].to(device) + input_bbox = in_bboxes[i] + task_queue = task_queues[i] + + interrupted = False + while len(task_queue) > 0: + if state.interrupted: interrupted = True ; break + + # DEBUG: current task + # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape) + task = task_queue.pop(0) + if task[0] == 'pre_norm': + group_norm_param.add_tile(tile, task[1]) + break + elif task[0] == 'store_res' or task[0] == 'store_res_cpu': + task_id = 0 + res = task[1](tile) + if not self.fast_mode or task[0] == 'store_res_cpu': + res = res.cpu() + while task_queue[task_id][0] != 'add_res': + task_id += 1 + task_queue[task_id][1] = res + elif task[0] == 'add_res': + tile += task[1].to(device) + task[1] = None + else: + tile = task[1](tile) + pbar.update(1) + + if interrupted: break + + # check for NaNs in the tile. + # If there are NaNs, we abort the process to save user's time + devices.test_for_nans(tile, "vae") + + if len(task_queue) == 0: + tiles[i] = None + num_completed += 1 + if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically + result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False) + result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder) + del tile + elif i == num_tiles - 1 and forward: + forward = False + tiles[i] = tile + elif i == 0 and not forward: + forward = True + tiles[i] = tile + else: + tiles[i] = tile.cpu() + del tile + + if interrupted: break + if num_completed == num_tiles: break + + # insert the group norm task to the head of each task queue + group_norm_func = group_norm_param.summary() + if group_norm_func is not None: + for i in range(num_tiles): + task_queue = task_queues[i] + task_queue.insert(0, ('apply_norm', group_norm_func)) + + # Done! 
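The zigzag visiting order used in the loop above can be illustrated with a tiny standalone sketch (not part of the script; the real loop only flips direction when the boundary tile still has pending tasks). Reversing the direction on every pass means the tile left resident on the GPU at the end of one pass is the first one needed in the next, saving one RAM-to-VRAM transfer per pass.

```python
# simplified zigzag schedule over 4 tiles for 3 passes
num_tiles = 4
forward = True
for _ in range(3):
    order = list(range(num_tiles)) if forward else list(reversed(range(num_tiles)))
    print(order)
    forward = not forward
# [0, 1, 2, 3]
# [3, 2, 1, 0]
# [0, 1, 2, 3]
```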
+ pbar.close() + return result.to(dtype) if result is not None else result_approx.to(device, dtype=dtype) + + +class Script(scripts.Script): + + def __init__(self): + self.hooked = False + + def title(self): + return "Tiled VAE" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + tab = 't2i' if not is_img2img else 'i2i' + uid = lambda name: f'MD-{tab}-{name}' + + with ( + InputAccordion(False, label='Tiled VAE', elem_id=f'MDV-{tab}-enabled') if InputAccordion + else gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}') + as enabled + ): + with gr.Row() as tab_enable: + if not InputAccordion: + enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable')) + vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu')) + + gr.HTML('Recommended to set tile sizes as large as possible before got CUDA error: out of memory.
') + with gr.Row() as tab_size: + encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size')) + decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size')) + reset = gr.Button(value='↻ Reset', variant='tool') + reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False) + + with gr.Row() as tab_param: + fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc')) + color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix')) + fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec')) + + fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False) + + return [ + enabled, + encoder_tile_size, decoder_tile_size, + vae_to_gpu, fast_decoder, fast_encoder, color_fix, + ] + + def process(self, p:Processing, + enabled:bool, + encoder_tile_size:int, decoder_tile_size:int, + vae_to_gpu:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool + ): + + # for shorthand + vae = p.sd_model.first_stage_model + encoder = vae.encoder + decoder = vae.decoder + + # undo hijack if disabled (in cases last time crashed) + if not enabled: + if self.hooked: + if isinstance(encoder.forward, VAEHook): + encoder.forward.net = None + encoder.forward = encoder.original_forward + if isinstance(decoder.forward, VAEHook): + decoder.forward.net = None + decoder.forward = decoder.original_forward + self.hooked = False + return + + if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu: + print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.") + + # do hijack + kwargs = { + 'fast_decoder': fast_decoder, + 'fast_encoder': fast_encoder, + 'color_fix': color_fix, + 'to_gpu': vae_to_gpu, + } + + # save original forward (only once) + if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward) + if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward) + + self.hooked = True + + encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs) + decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs) + + def postprocess(self, p:Processing, processed, enabled:bool, *args): + if not enabled: return + + vae = p.sd_model.first_stage_model + encoder = vae.encoder + decoder = vae.decoder + if isinstance(encoder.forward, VAEHook): + encoder.forward.net = None + encoder.forward = encoder.original_forward + if isinstance(decoder.forward, VAEHook): + decoder.forward.net = None + decoder.forward = decoder.original_forward diff --git a/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/abstractdiffusion.cpython-310.pyc b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/abstractdiffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66ae03f7edca430c79106c46dd4c4c7c1a93c46 Binary files /dev/null and b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/abstractdiffusion.cpython-310.pyc differ diff --git a/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/demofusion.cpython-310.pyc 
b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/demofusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65a76e620772a7c636a546cbc3a2ac303aa288fc Binary files /dev/null and b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/demofusion.cpython-310.pyc differ diff --git a/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/mixtureofdiffusers.cpython-310.pyc b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/mixtureofdiffusers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5be405eb4d1624ef8cc57438115d225ef2c11847 Binary files /dev/null and b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/mixtureofdiffusers.cpython-310.pyc differ diff --git a/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/multidiffusion.cpython-310.pyc b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/multidiffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17f15474153a32a34eabe560a548593b8302fe87 Binary files /dev/null and b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/__pycache__/multidiffusion.cpython-310.pyc differ diff --git a/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/abstractdiffusion.py b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/abstractdiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..88916f3881479355d6ac0ab0425836a09255e126 --- /dev/null +++ b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/abstractdiffusion.py @@ -0,0 +1,747 @@ +from tile_utils.utils import * + + +class AbstractDiffusion: + + def __init__(self, p: Processing, sampler: Sampler): + self.method = self.__class__.__name__ + self.p: Processing = p + self.pbar = None + + # sampler + self.sampler_name = p.sampler_name + self.sampler_raw = sampler + self.sampler = sampler + + # fix. Kdiff 'AND' support and image editing model support + if self.is_kdiff and not hasattr(self, 'is_edit_model'): + self.is_edit_model = (shared.sd_model.cond_stage_key == "edit" # "txt" + and self.sampler.model_wrap_cfg.image_cfg_scale is not None + and self.sampler.model_wrap_cfg.image_cfg_scale != 1.0) + + # cache. final result of current sampling step, [B, C=4, H//8, W//8] + # avoiding overhead of creating new tensors and weight summing + self.x_buffer: Tensor = None + self.w: int = int(self.p.width // opt_f) # latent size + self.h: int = int(self.p.height // opt_f) + # weights for background & grid bboxes + self.weights: Tensor = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32) + + # FIXME: I'm trying to count the step correctly but it's not working + self.step_count = 0 + self.inner_loop_count = 0 + self.kdiff_step = -1 + + # ext. Grid tiling painting (grid bbox) + self.enable_grid_bbox: bool = False + self.tile_w: int = None + self.tile_h: int = None + self.tile_bs: int = None + self.num_tiles: int = None + self.num_batches: int = None + self.batched_bboxes: List[List[BBox]] = [] + + # ext. Region Prompt Control (custom bbox) + self.enable_custom_bbox: bool = False + self.custom_bboxes: List[CustomBBox] = [] + self.cond_basis: Cond = None + self.uncond_basis: Uncond = None + self.draw_background: bool = True # by default we draw major prompts in grid tiles + self.causal_layers: bool = None + + # ext. 
Noise Inversion (noise inversion) + self.noise_inverse_enabled: bool = False + self.noise_inverse_steps: int = 0 + self.noise_inverse_retouch: float = None + self.noise_inverse_renoise_strength: float = None + self.noise_inverse_renoise_kernel: int = None + self.noise_inverse_get_cache = None + self.noise_inverse_set_cache = None + self.sample_img2img_original = None + + # ext. ControlNet + self.enable_controlnet: bool = False + self.controlnet_script: ModuleType = None + self.control_tensor_batch: List[List[Tensor]] = [] + self.control_params: Dict[str, Tensor] = {} + self.control_tensor_cpu: bool = None + self.control_tensor_custom: List[List[Tensor]] = [] + + # ext. StableSR + self.enable_stablesr: bool = False + self.stablesr_script: ModuleType = None + self.stablesr_tensor: Tensor = None + self.stablesr_tensor_batch: List[Tensor] = [] + self.stablesr_tensor_custom: List[Tensor] = [] + + @property + def is_kdiff(self): + return isinstance(self.sampler_raw, KDiffusionSampler) + + @property + def is_ddim(self): + return isinstance(self.sampler_raw, CompVisSampler) + + def update_pbar(self): + if self.pbar.n >= self.pbar.total: + self.pbar.close() + else: + if self.step_count == state.sampling_step: + self.inner_loop_count += 1 + if self.inner_loop_count < self.total_bboxes: + self.pbar.update() + else: + self.step_count = state.sampling_step + self.inner_loop_count = 0 + + def reset_buffer(self, x_in:Tensor): + # Judge if the shape of x_in is the same as the shape of x_buffer + if self.x_buffer is None or self.x_buffer.shape != x_in.shape: + self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype) + else: + self.x_buffer.zero_() + + def init_done(self): + ''' + Call this after all `init_*`, settings are done, now perform: + - settings sanity check + - pre-computations, cache init + - anything thing needed before denoising starts + ''' + + self.total_bboxes = 0 + if self.enable_grid_bbox: self.total_bboxes += self.num_batches + if self.enable_custom_bbox: self.total_bboxes += len(self.custom_bboxes) + assert self.total_bboxes > 0, "Nothing to paint! No background to draw and no custom bboxes were provided." 
+ + self.pbar = tqdm(total=(self.total_bboxes) * state.sampling_steps, desc=f"{self.method} Sampling: ") + + ''' ↓↓↓ cond_dict utils ↓↓↓ ''' + + def _tcond_key(self, cond_dict:CondDict) -> str: + return 'crossattn' if 'crossattn' in cond_dict else 'c_crossattn' + + def get_tcond(self, cond_dict:CondDict) -> Tensor: + tcond = cond_dict[self._tcond_key(cond_dict)] + if isinstance(tcond, list): tcond = tcond[0] + return tcond + + def set_tcond(self, cond_dict:CondDict, tcond:Tensor): + key = self._tcond_key(cond_dict) + if isinstance(cond_dict[key], list): tcond = [tcond] + cond_dict[key] = tcond + + def _icond_key(self, cond_dict:CondDict) -> str: + return 'c_adm' if shared.sd_model.model.conditioning_key in ['crossattn-adm', 'adm'] else 'c_concat' + + def get_icond(self, cond_dict:CondDict) -> Tensor: + ''' icond differs for different models (inpaint/unclip model) ''' + key = self._icond_key(cond_dict) + icond = cond_dict[key] + if isinstance(icond, list): icond = icond[0] + return icond + + def set_icond(self, cond_dict:CondDict, icond:Tensor): + key = self._icond_key(cond_dict) + if isinstance(cond_dict[key], list): icond = [icond] + cond_dict[key] = icond + + def _vcond_key(self, cond_dict:CondDict) -> Optional[str]: + return 'vector' if 'vector' in cond_dict else None + + def get_vcond(self, cond_dict:CondDict) -> Optional[Tensor]: + ''' vector for SDXL ''' + key = self._vcond_key(cond_dict) + return cond_dict.get(key) + + def set_vcond(self, cond_dict:CondDict, vcond:Optional[Tensor]): + key = self._vcond_key(cond_dict) + if key is not None: + cond_dict[key] = vcond + + def make_cond_dict(self, cond_in:CondDict, tcond:Tensor, icond:Tensor, vcond:Tensor=None) -> CondDict: + ''' copy & replace the content, returns a new object ''' + cond_out = cond_in.copy() + self.set_tcond(cond_out, tcond) + self.set_icond(cond_out, icond) + self.set_vcond(cond_out, vcond) + return cond_out + + ''' ↓↓↓ extensive functionality ↓↓↓ ''' + + @grid_bbox + def init_grid_bbox(self, tile_w:int, tile_h:int, overlap:int, tile_bs:int): + self.enable_grid_bbox = True + + self.tile_w = min(tile_w, self.w) + self.tile_h = min(tile_h, self.h) + overlap = max(0, min(overlap, min(tile_w, tile_h) - 4)) + # split the latent into overlapped tiles, then batching + # weights basically indicate how many times a pixel is painted + bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights()) + self.weights += weights + self.num_tiles = len(bboxes) + self.num_batches = math.ceil(self.num_tiles / tile_bs) + self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size + self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)] + + @grid_bbox + def get_tile_weights(self) -> Union[Tensor, float]: + return 1.0 + + + @custom_bbox + def init_custom_bbox(self, bbox_settings:Dict[int,BBoxSettings], draw_background:bool, causal_layers:bool): + self.enable_custom_bbox = True + + self.causal_layers = causal_layers + self.draw_background = draw_background + if not draw_background: + self.enable_grid_bbox = False + self.weights.zero_() + + self.custom_bboxes: List[CustomBBox] = [] + for bbox_setting in bbox_settings.values(): + e, x, y, w, h, p, n, blend_mode, feather_ratio, seed = bbox_setting + if not e or x > 1.0 or y > 1.0 or w <= 0.0 or h <= 0.0: continue + x = int(x * self.w) + y = int(y * self.h) + w = math.ceil(w * self.w) + h = math.ceil(h * self.h) + x = max(0, x) + y = max(0, y) + w = min(self.w - x, w) + h = min(self.h - y, h) 
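The normalized-to-latent conversion above can be written as a small standalone helper. This is an illustrative sketch with hypothetical names, not the extension's API: region settings arrive as fractions of the canvas in [0, 1] and are mapped to integer latent coordinates, clamped so the box stays inside the latent grid.

```python
import math

def to_latent_bbox(x: float, y: float, w: float, h: float, lat_w: int, lat_h: int):
    """Map a normalized region (fractions of the canvas) to clamped latent-space pixel coordinates."""
    px = max(0, int(x * lat_w))
    py = max(0, int(y * lat_h))
    pw = min(lat_w - px, math.ceil(w * lat_w))
    ph = min(lat_h - py, math.ceil(h * lat_h))
    return px, py, pw, ph

# e.g. a region covering the left half of a 96x96 latent (a 768x768 image at the usual latent factor of 8)
print(to_latent_bbox(0.0, 0.0, 0.5, 1.0, 96, 96))   # -> (0, 0, 48, 96)
```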
+ self.custom_bboxes.append(CustomBBox(x, y, w, h, p, n, blend_mode, feather_ratio, seed)) + + if len(self.custom_bboxes) == 0: + self.enable_custom_bbox = False + return + + # prepare cond + p = self.p + prompts = p.all_prompts[:p.batch_size] + neg_prompts = p.all_negative_prompts[:p.batch_size] + for bbox in self.custom_bboxes: + bbox.cond, bbox.extra_network_data = Condition.get_custom_cond(prompts, bbox.prompt, p.steps, p.styles) + bbox.uncond = Condition.get_uncond(Prompt.append_prompt(neg_prompts, bbox.neg_prompt), p.steps, p.styles) + self.cond_basis = Condition.get_cond(prompts, p.steps) + self.uncond_basis = Condition.get_uncond(neg_prompts, p.steps) + + @custom_bbox + def reconstruct_custom_cond(self, org_cond:CondDict, custom_cond:Cond, custom_uncond:Uncond, bbox:CustomBBox) -> Tuple[List, Tensor, Uncond, Tensor]: + image_conditioning = None + if isinstance(org_cond, dict): + icond = self.get_icond(org_cond) + if icond.shape[2:] == (self.h, self.w): # img2img + icond = icond[bbox.slicer] + image_conditioning = icond + + sampler_step = self.sampler.model_wrap_cfg.step + tensor = Condition.reconstruct_cond(custom_cond, sampler_step) + custom_uncond = Condition.reconstruct_uncond(custom_uncond, sampler_step) + return tensor, custom_uncond, image_conditioning + + @custom_bbox + def kdiff_custom_forward(self, x_tile:Tensor, sigma_in:Tensor, original_cond:CondDict, bbox_id:int, bbox:CustomBBox, forward_func:Callable) -> Tensor: + ''' + The inner kdiff noise prediction is usually batched. + We need to unwrap the inside loop to simulate the batched behavior. + This can be extremely tricky. + ''' + + sampler_step = self.sampler.model_wrap_cfg.step + if self.kdiff_step != sampler_step: + self.kdiff_step = sampler_step + self.kdiff_step_bbox = [-1 for _ in range(len(self.custom_bboxes))] + self.tensor = {} # {int: Tensor[cond]} + self.uncond = {} # {int: Tensor[cond]} + self.image_cond_in = {} + # Initialize global prompts just for estimate the behavior of kdiff + self.real_tensor = Condition.reconstruct_cond(self.cond_basis, sampler_step) + self.real_uncond = Condition.reconstruct_uncond(self.uncond_basis, sampler_step) + # reset the progress for all bboxes + self.a = [0 for _ in range(len(self.custom_bboxes))] + + if self.kdiff_step_bbox[bbox_id] != sampler_step: + # When a new step starts for a bbox, we need to judge whether the tensor is batched. + self.kdiff_step_bbox[bbox_id] = sampler_step + + tensor, uncond, image_cond_in = self.reconstruct_custom_cond(original_cond, bbox.cond, bbox.uncond, bbox) + + if self.real_tensor.shape[1] == self.real_uncond.shape[1]: + if shared.batch_cond_uncond: + # when the real tensor is with equal length, all information is contained in x_tile. + # we simulate the batched behavior and compute all the tensors in one go. + if tensor.shape[1] == uncond.shape[1]: + # When our prompt tensor is with equal length, we can directly their code. + if not self.is_edit_model: + cond = torch.cat([tensor, uncond]) + else: + cond = torch.cat([tensor, uncond, uncond]) + self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, cond, image_cond_in), + ) + else: + # When not, we need to pass the tensor to UNet separately. 
+ x_out = torch.zeros_like(x_tile) + cond_size = tensor.shape[0] + self.set_custom_controlnet_tensors(bbox_id, cond_size) + self.set_custom_stablesr_tensors(bbox_id) + cond_out = forward_func( + x_tile [:cond_size], + sigma_in[:cond_size], + cond=self.make_cond_dict(original_cond, tensor, image_cond_in[:cond_size]), + ) + uncond_size = uncond.shape[0] + self.set_custom_controlnet_tensors(bbox_id, uncond_size) + self.set_custom_stablesr_tensors(bbox_id) + uncond_out = forward_func( + x_tile [cond_size:cond_size+uncond_size], + sigma_in[cond_size:cond_size+uncond_size], + cond=self.make_cond_dict(original_cond, uncond, image_cond_in[cond_size:cond_size+uncond_size]), + ) + x_out[:cond_size] = cond_out + x_out[cond_size:cond_size+uncond_size] = uncond_out + if self.is_edit_model: + x_out[cond_size+uncond_size:] = uncond_out + return x_out + + # otherwise, the x_tile is only a partial batch. + # We have to denoise in different runs. + # We store the prompt and neg_prompt tensors for current bbox + self.tensor[bbox_id] = tensor + self.uncond[bbox_id] = uncond + self.image_cond_in[bbox_id] = image_cond_in + + # Now we get current batch of prompt and neg_prompt tensors + tensor: Tensor = self.tensor[bbox_id] + uncond: Tensor = self.uncond[bbox_id] + batch_size = x_tile.shape[0] + # get the start and end index of the current batch + a = self.a[bbox_id] + b = a + batch_size + self.a[bbox_id] += batch_size + + if self.real_tensor.shape[1] == self.real_uncond.shape[1]: + # When use --lowvram or --medvram, kdiff will slice the cond and uncond with [a:b] + # So we need to slice our tensor and uncond with the same index as original kdiff. + + # --- original code in kdiff --- + # if not self.is_edit_model: + # cond = torch.cat([tensor, uncond]) + # else: + # cond = torch.cat([tensor, uncond, uncond]) + # cond = cond[a:b] + # ------------------------------ + + # The original kdiff code is to concat and then slice, but this cannot apply to + # our custom prompt tensor when tensor.shape[1] != uncond.shape[1]. So we adapt it. + cond_in, uncond_in = None, None + # Slice the [prompt, neg prompt, (possibly) neg prompt] with [a:b] + if not self.is_edit_model: + if b <= tensor.shape[0]: cond_in = tensor[a:b] + elif a >= tensor.shape[0]: cond_in = uncond[a-tensor.shape[0]:b-tensor.shape[0]] + else: + cond_in = tensor[a:] + uncond_in = uncond[:b-tensor.shape[0]] + else: + if b <= tensor.shape[0]: + cond_in = tensor[a:b] + elif b > tensor.shape[0] and b <= tensor.shape[0] + uncond.shape[0]: + if a>= tensor.shape[0]: + cond_in = uncond[a-tensor.shape[0]:b-tensor.shape[0]] + else: + cond_in = tensor[a:] + uncond_in = uncond[:b-tensor.shape[0]] + else: + if a >= tensor.shape[0] + uncond.shape[0]: + cond_in = uncond[a-tensor.shape[0]-uncond.shape[0]:b-tensor.shape[0]-uncond.shape[0]] + elif a >= tensor.shape[0]: + cond_in = torch.cat([uncond[a-tensor.shape[0]:], uncond[:b-tensor.shape[0]-uncond.shape[0]]]) + + if uncond_in is None or tensor.shape[1] == uncond.shape[1]: + # If the tensor can be passed to UNet in one go, do it. + if uncond_in is not None: + cond_in = torch.cat([cond_in, uncond_in]) + self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, cond_in, self.image_cond_in[bbox_id]), + ) + else: + # If not, we need to pass the tensor to UNet separately. 
+ x_out = torch.zeros_like(x_tile) + cond_size = cond_in.shape[0] + self.set_custom_controlnet_tensors(bbox_id, cond_size) + self.set_custom_stablesr_tensors(bbox_id) + cond_out = forward_func( + x_tile [:cond_size], + sigma_in[:cond_size], + cond=self.make_cond_dict(original_cond, cond_in, self.image_cond_in[bbox_id]) + ) + self.set_custom_controlnet_tensors(bbox_id, uncond_in.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + uncond_out = forward_func( + x_tile [cond_size:], + sigma_in[cond_size:], + cond=self.make_cond_dict(original_cond, uncond_in, self.image_cond_in[bbox_id]) + ) + x_out[:cond_size] = cond_out + x_out[cond_size:] = uncond_out + return x_out + + # If the original prompt is with different length, + # kdiff will deal with the cond and uncond separately. + # Hence we also deal with the tensor and uncond separately. + # get the start and end index of the current batch + + if a < tensor.shape[0]: + # Deal with custom prompt tensor + if not self.is_edit_model: + c_crossattn = tensor[a:b] + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + self.set_custom_controlnet_tensors(bbox_id, x_tile.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + # complete this batch. + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, c_crossattn, self.image_cond_in[bbox_id]) + ) + else: + # if the cond is finished, we need to process the uncond. + self.set_custom_controlnet_tensors(bbox_id, uncond.shape[0]) + self.set_custom_stablesr_tensors(bbox_id) + return forward_func( + x_tile, + sigma_in, + cond=self.make_cond_dict(original_cond, uncond, self.image_cond_in[bbox_id]) + ) + + @custom_bbox + def ddim_custom_forward(self, x:Tensor, cond_in:CondDict, bbox:CustomBBox, ts:Tensor, forward_func:Callable, *args, **kwargs) -> Tensor: + ''' draw custom bbox ''' + + tensor, uncond, image_conditioning = self.reconstruct_custom_cond(cond_in, bbox.cond, bbox.uncond, bbox) + + cond = tensor + # for DDIM, shapes definitely match. So we dont need to do the same thing as in the KDIFF sampler. + if uncond.shape[1] < cond.shape[1]: + last_vector = uncond[:, -1:] + last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond.shape[1], 1]) + uncond = torch.hstack([uncond, last_vector_repeated]) + elif uncond.shape[1] > cond.shape[1]: + uncond = uncond[:, :cond.shape[1]] + + # Wrap the image conditioning back up since the DDIM code can accept the dict directly. + # Note that they need to be lists because it just concatenates them later. + if image_conditioning is not None: + cond = self.make_cond_dict(cond_in, cond, image_conditioning) + uncond = self.make_cond_dict(cond_in, uncond, image_conditioning) + + # We cannot determine the batch size here for different methods, so delay it to the forward_func. 
+ return forward_func(x, cond, ts, unconditional_conditioning=uncond, *args, **kwargs) + + + @controlnet + def init_controlnet(self, controlnet_script:ModuleType, control_tensor_cpu:bool): + self.enable_controlnet = True + + self.controlnet_script = controlnet_script + self.control_tensor_cpu = control_tensor_cpu + self.control_tensor_batch = None + self.control_params = None + self.control_tensor_custom = [] + + self.prepare_controlnet_tensors() + + @controlnet + def reset_controlnet_tensors(self): + if not self.enable_controlnet: return + if self.control_tensor_batch is None: return + + for param_id in range(len(self.control_params)): + self.control_params[param_id].hint_cond = self.org_control_tensor_batch[param_id] + + @controlnet + def prepare_controlnet_tensors(self, refresh:bool=False): + ''' Crop the control tensor into tiles and cache them ''' + + if not refresh: + if self.control_tensor_batch is not None or self.control_params is not None: return + + if not self.enable_controlnet or self.controlnet_script is None: return + + latest_network = self.controlnet_script.latest_network + if latest_network is None or not hasattr(latest_network, 'control_params'): return + + self.control_params = latest_network.control_params + tensors = [param.hint_cond for param in latest_network.control_params] + self.org_control_tensor_batch = tensors + + if len(tensors) == 0: return + + self.control_tensor_batch = [] + for i in range(len(tensors)): + control_tile_list = [] + control_tensor = tensors[i] + for bboxes in self.batched_bboxes: + single_batch_tensors = [] + for bbox in bboxes: + if len(control_tensor.shape) == 3: + control_tensor.unsqueeze_(0) + control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f] + single_batch_tensors.append(control_tile) + control_tile = torch.cat(single_batch_tensors, dim=0) + if self.control_tensor_cpu: + control_tile = control_tile.cpu() + control_tile_list.append(control_tile) + self.control_tensor_batch.append(control_tile_list) + + if len(self.custom_bboxes) > 0: + custom_control_tile_list = [] + for bbox in self.custom_bboxes: + if len(control_tensor.shape) == 3: + control_tensor.unsqueeze_(0) + control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f] + if self.control_tensor_cpu: + control_tile = control_tile.cpu() + custom_control_tile_list.append(control_tile) + self.control_tensor_custom.append(custom_control_tile_list) + + @controlnet + def switch_controlnet_tensors(self, batch_id:int, x_batch_size:int, tile_batch_size:int, is_denoise=False): + if not self.enable_controlnet: return + if self.control_tensor_batch is None: return + + for param_id in range(len(self.control_params)): + control_tile = self.control_tensor_batch[param_id][batch_id] + if self.is_kdiff: + all_control_tile = [] + for i in range(tile_batch_size): + this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size + all_control_tile.append(torch.cat(this_control_tile, dim=0)) + control_tile = torch.cat(all_control_tile, dim=0) + else: + control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1]) + self.control_params[param_id].hint_cond = control_tile.to(devices.device) + + @controlnet + def set_custom_controlnet_tensors(self, bbox_id:int, repeat_size:int): + if not self.enable_controlnet: return + if not len(self.control_tensor_custom): return + + for param_id in range(len(self.control_params)): + control_tensor = 
self.control_tensor_custom[param_id][bbox_id].to(devices.device) + self.control_params[param_id].hint_cond = control_tensor.repeat((repeat_size, 1, 1, 1)) + + + @stablesr + def init_stablesr(self, stablesr_script:ModuleType): + if stablesr_script.stablesr_model is None: return + self.stablesr_script = stablesr_script + def set_image_hook(latent_image): + self.enable_stablesr = True + self.stablesr_tensor = latent_image + self.stablesr_tensor_batch = [] + for bboxes in self.batched_bboxes: + single_batch_tensors = [] + for bbox in bboxes: + stablesr_tile = self.stablesr_tensor[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] + single_batch_tensors.append(stablesr_tile) + stablesr_tile = torch.cat(single_batch_tensors, dim=0) + self.stablesr_tensor_batch.append(stablesr_tile) + if len(self.custom_bboxes) > 0: + self.stablesr_tensor_custom = [] + for bbox in self.custom_bboxes: + stablesr_tile = self.stablesr_tensor[:, :, bbox[1]:bbox[3], bbox[0]:bbox[2]] + self.stablesr_tensor_custom.append(stablesr_tile) + + stablesr_script.stablesr_model.set_image_hooks['TiledDiffusion'] = set_image_hook + + @stablesr + def reset_stablesr_tensors(self): + if not self.enable_stablesr: return + if self.stablesr_script.stablesr_model is None: return + self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor + + @stablesr + def switch_stablesr_tensors(self, batch_id:int): + if not self.enable_stablesr: return + if self.stablesr_script.stablesr_model is None: return + if self.stablesr_tensor_batch is None: return + self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor_batch[batch_id] + + @stablesr + def set_custom_stablesr_tensors(self, bbox_id:int): + if not self.enable_stablesr: return + if self.stablesr_script.stablesr_model is None: return + if not len(self.stablesr_tensor_custom): return + self.stablesr_script.stablesr_model.latent_image = self.stablesr_tensor_custom[bbox_id] + + + @noise_inverse + def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int): + self.noise_inverse_enabled = True + self.noise_inverse_steps = steps + self.noise_inverse_retouch = float(retouch) + self.noise_inverse_renoise_strength = float(renoise_strength) + self.noise_inverse_renoise_kernel = int(renoise_kernel) + if self.sample_img2img_original is None: + self.sample_img2img_original = self.sampler_raw.sample_img2img + self.sampler_raw.sample_img2img = MethodType(self.sample_img2img, self.sampler_raw) + self.noise_inverse_set_cache = set_cache_callback + self.noise_inverse_get_cache = get_cache_callback + + @noise_inverse + @keep_signature + def sample_img2img(self, sampler: KDiffusionSampler, p:ProcessingImg2Img, + x:Tensor, noise:Tensor, conditioning, unconditional_conditioning, + steps=None, image_conditioning=None): + # noise inverse sampling - renoise mask + import torch.nn.functional as F + renoise_mask = None + if self.noise_inverse_renoise_strength > 0: + image = p.init_images[0] + # convert to grayscale with PIL + image = image.convert('L') + np_mask = get_retouch_mask(np.asarray(image), self.noise_inverse_renoise_kernel) + renoise_mask = torch.from_numpy(np_mask).to(noise.device) + # resize retouch mask to match noise size + renoise_mask = 1 - F.interpolate(renoise_mask.unsqueeze(0).unsqueeze(0), size=noise.shape[-2:], mode='bilinear').squeeze(0).squeeze(0) + renoise_mask *= self.noise_inverse_renoise_strength + renoise_mask = torch.clamp(renoise_mask, 0, 1) + + prompts = p.all_prompts[:p.batch_size] + + latent = 
None + # try to use cached latent to save huge amount of time. + cached_latent: NoiseInverseCache = self.noise_inverse_get_cache() + if cached_latent is not None and \ + cached_latent.model_hash == p.sd_model.sd_model_hash and \ + cached_latent.noise_inversion_steps == self.noise_inverse_steps and \ + len(cached_latent.prompts) == len(prompts) and \ + all([cached_latent.prompts[i] == prompts[i] for i in range(len(prompts))]) and \ + abs(cached_latent.retouch - self.noise_inverse_retouch) < 0.01 and \ + cached_latent.x0.shape == p.init_latent.shape and \ + torch.abs(cached_latent.x0.to(p.init_latent.device) - p.init_latent).sum() < 100: # the 100 is an arbitrary threshold copy-pasted from the img2img alt code + # use cached noise + print('[Tiled Diffusion] Your checkpoint, image, prompts, inverse steps, and retouch params are all unchanged.') + print('[Tiled Diffusion] Noise Inversion will use the cached noise from the previous run. To clear the cache, click the Free GPU button.') + latent = cached_latent.xt.to(noise.device) + if latent is None: + # run noise inversion + shared.state.job_count += 1 + latent = self.find_noise_for_image_sigma_adjustment(sampler.model_wrap, self.noise_inverse_steps, prompts) + shared.state.nextjob() + self.noise_inverse_set_cache(p.init_latent.clone().cpu(), latent.clone().cpu(), prompts) + # The cache is only 1 latent image and is very small (16 MB for 8192 * 8192 image), so we don't need to worry about memory leakage. + + # calculate sampling steps + adjusted_steps, _ = sd_samplers_common.setup_img2img_steps(p, steps) + sigmas = sampler.get_sigmas(p, adjusted_steps) + inverse_noise = latent - (p.init_latent / sigmas[0]) + + # inject noise to high-frequency area so that the details won't lose too much + if renoise_mask is not None: + # If the background is not drawn, we need to filter out the un-drawn pixels and reweight foreground with feather mask + # This is to enable the renoise mask in regional inpainting + if not self.enable_grid_bbox: + background_count = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device) + foreground_noise = torch.zeros_like(noise) + foreground_weight = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device) + foreground_count = torch.zeros((1, 1, noise.shape[2], noise.shape[3]), device=noise.device) + for bbox in self.custom_bboxes: + if bbox.blend_mode == BlendMode.BACKGROUND: + background_count[bbox.slicer] += 1 + elif bbox.blend_mode == BlendMode.FOREGROUND: + foreground_noise [bbox.slicer] += noise[bbox.slicer] + foreground_weight[bbox.slicer] += bbox.feather_mask + foreground_count [bbox.slicer] += 1 + background_noise = torch.where(background_count > 0, noise, 0) + foreground_noise = torch.where(foreground_count > 0, foreground_noise / foreground_count, 0) + foreground_weight = torch.where(foreground_count > 0, foreground_weight / foreground_count, 0) + noise = background_noise * (1 - foreground_weight) + foreground_noise * foreground_weight + del background_noise, foreground_noise, foreground_weight, background_count, foreground_count + combined_noise = ((1 - renoise_mask) * inverse_noise + renoise_mask * noise) / ((renoise_mask**2 + (1 - renoise_mask)**2) ** 0.5) + else: + combined_noise = inverse_noise + + # use the estimated noise for the original img2img sampling + return self.sample_img2img_original(p, x, combined_noise, conditioning, unconditional_conditioning, steps, image_conditioning) + + @noise_inverse + @torch.no_grad() + def find_noise_for_image_sigma_adjustment(self, dnw, 
steps, prompts:List[str]) -> Tensor: + ''' + Migrate from the built-in script img2imgalt.py + Tiled noise inverse for better image upscaling + ''' + import k_diffusion as K + assert self.p.sampler_name == 'Euler' + + x = self.p.init_latent + s_in = x.new_ones([x.shape[0]]) + skip = 1 if shared.sd_model.parameterization == "v" else 0 + sigmas = dnw.get_sigmas(steps).flip(0) + + cond = self.p.sd_model.get_learned_conditioning(prompts) + if isinstance(cond, Tensor): # SD1/SD2 + cond_dict_dummy = { + 'c_crossattn': [], # List[Tensor] + 'c_concat': [], # List[Tensor] + } + cond_in = self.make_cond_dict(cond_dict_dummy, cond, self.p.image_conditioning) + else: # SDXL + cond_dict_dummy = { + 'crossattn': None, # Tensor + 'vector': None, # Tensor + 'c_concat': [], # List[Tensor] + } + cond_in = self.make_cond_dict(cond_dict_dummy, cond['crossattn'], self.p.image_conditioning, cond['vector']) + + state.sampling_steps = steps + pbar = tqdm(total=steps, desc='Noise Inversion') + for i in range(1, len(sigmas)): + if state.interrupted: return x + + state.sampling_step += 1 + + x_in = x + sigma_in = torch.cat([sigmas[i] * s_in]) + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]] + + t = dnw.sigma_to_t(sigma_in) + t = t / self.noise_inverse_retouch + + eps = self.get_noise(x_in * c_in, t, cond_in, steps - i) + denoised = x_in + eps * c_out + + # Euler method: + d = (x - denoised) / sigmas[i] + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + + sd_samplers_common.store_latent(x) + + # This is neccessary to save memory before the next iteration + del x_in, sigma_in, c_out, c_in, t, + del eps, denoised, d, dt + + pbar.update(1) + pbar.close() + + return x / sigmas[-1] + + @noise_inverse + @torch.no_grad() + def get_noise(self, x_in: Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: + raise NotImplementedError diff --git a/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/demofusion.py b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/demofusion.py new file mode 100644 index 0000000000000000000000000000000000000000..758ccfe0d13c6e92b878660a475905cf325a29fd --- /dev/null +++ b/extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/demofusion.py @@ -0,0 +1,353 @@ +from tile_methods.abstractdiffusion import AbstractDiffusion +from tile_utils.utils import * +import torch.nn.functional as F +import random +from copy import deepcopy +import inspect +from modules import sd_samplers_common + + +class DemoFusion(AbstractDiffusion): + """ + DemoFusion Implementation + https://arxiv.org/abs/2311.16973 + """ + + def __init__(self, p:Processing, *args, **kwargs): + super().__init__(p, *args, **kwargs) + assert p.sampler_name != 'UniPC', 'Demofusion is not compatible with UniPC!' 
+ + + def hook(self): + steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None) + + self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward + self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward + self.sampler.model_wrap_cfg.forward = self.forward_one_step + if self.is_kdiff: + self.sampler: KDiffusionSampler + self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion + self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser] + else: + self.sampler: CompVisSampler + self.sampler.model_wrap_cfg: CFGDenoiserTimesteps + self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser] + self.timesteps = self.sampler.get_timesteps(self.p, steps) + + @staticmethod + def unhook(): + if hasattr(shared.sd_model, 'apply_model_ori'): + shared.sd_model.apply_model = shared.sd_model.apply_model_ori + del shared.sd_model.apply_model_ori + + def reset_buffer(self, x_in:Tensor): + super().reset_buffer(x_in) + + + + def repeat_tensor(self, x:Tensor, n:int) -> Tensor: + ''' repeat the tensor on it's first dim ''' + if n == 1: return x + B = x.shape[0] + r_dims = len(x.shape) - 1 + if B == 1: # batch_size = 1 (not `tile_batch_size`) + shape = [n] + [-1] * r_dims # [N, -1, ...] + return x.expand(shape) # `expand` is much lighter than `tile` + else: + shape = [n] + [1] * r_dims # [N, 1, ...] + return x.repeat(shape) + + def repeat_cond_dict(self, cond_in:CondDict, bboxes,mode) -> CondDict: + ''' repeat all tensors in cond_dict on it's first dim (for a batch of tiles), returns a new object ''' + # n_repeat + n_rep = len(bboxes) + # txt cond + tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D] + tcond = self.repeat_tensor(tcond, n_rep) + # img cond + icond = self.get_icond(cond_in) + if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W] + if mode == 0: + if self.p.random_jitter: + jitter_range = self.jitter_range + icond = F.pad(icond,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) + icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0) + else: + icond = torch.cat([icond[:,:,bbox[1]::self.p.current_scale_num,bbox[0]::self.p.current_scale_num] for bbox in bboxes], dim=0) + else: # txt2img, [B=1, C=5, H=1, W=1] + icond = self.repeat_tensor(icond, n_rep) + + # vec cond (SDXL) + vcond = self.get_vcond(cond_in) # [B=1, D] + if vcond is not None: + vcond = self.repeat_tensor(vcond, n_rep) # [B*N, D] + return self.make_cond_dict(cond_in, tcond, icond, vcond) + + + def global_split_bboxes(self): + cols = self.p.current_scale_num + rows = cols + + bbox_list = [] + for row in range(rows): + y = row + for col in range(cols): + x = col + bbox = (x, y) + bbox_list.append(bbox) + + return bbox_list+bbox_list if self.p.mixture else bbox_list + + def split_bboxes_jitter(self,w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]: + cols = math.ceil((w_l - overlap) / (tile_w - overlap)) + rows = math.ceil((h_l - overlap) / (tile_h - overlap)) + if rows==0: + rows=1 + if cols == 0: + cols=1 + dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0 + dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0 + bbox_list: List[BBox] = [] + self.jitter_range = 0 + for row in range(rows): + for col in range(cols): + h = min(int(row * dy), h_l - tile_h) + w = min(int(col * dx), w_l - tile_w) + if self.p.random_jitter: + self.jitter_range = min(max((min(self.w, 
self.h)-self.stride)//4,0),min(int(self.window_size/2),int(self.overlap/2))) + jitter_range = self.jitter_range + w_jitter = 0 + h_jitter = 0 + if (w != 0) and (w+tile_w != w_l): + w_jitter = random.randint(-jitter_range, jitter_range) + elif (w == 0) and (w + tile_w != w_l): + w_jitter = random.randint(-jitter_range, 0) + elif (w != 0) and (w + tile_w == w_l): + w_jitter = random.randint(0, jitter_range) + if (h != 0) and (h + tile_h != h_l): + h_jitter = random.randint(-jitter_range, jitter_range) + elif (h == 0) and (h + tile_h != h_l): + h_jitter = random.randint(-jitter_range, 0) + elif (h != 0) and (h + tile_h == h_l): + h_jitter = random.randint(0, jitter_range) + h +=(h_jitter + jitter_range) + w += (w_jitter + jitter_range) + + bbox = BBox(w, h, tile_w, tile_h) + bbox_list.append(bbox) + return bbox_list, None + + @grid_bbox + def get_views(self, overlap:int, tile_bs:int,tile_bs_g:int): + self.enable_grid_bbox = True + self.tile_w = self.window_size + self.tile_h = self.window_size + + self.overlap = max(0, min(overlap, self.window_size - 4)) + + self.stride = max(4,self.window_size - self.overlap) + + # split the latent into overlapped tiles, then batching + # weights basically indicate how many times a pixel is painted + bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights()) + self.num_tiles = len(bboxes) + self.num_batches = math.ceil(self.num_tiles / tile_bs) + self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size + self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)] + + global_bboxes = self.global_split_bboxes() + self.global_num_tiles = len(global_bboxes) + self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs_g) + self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches) + self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)] + + def gaussian_kernel(self,kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size, device=devices.device) + gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] + kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) + + return kernel + + def gaussian_filter(self,latents, kernel_size=3, sigma=1.0): + channels = latents.shape[1] + kernel = self.gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + + return blurred_latents + + + + ''' ↓↓↓ kernel hijacks ↓↓↓ ''' + @torch.no_grad() + @keep_signature + def forward_one_step(self, x_in, sigma, **kwarg): + if self.is_kdiff: + x_noisy = self.p.x + self.p.noise * sigma[0] + else: + alphas_cumprod = self.p.sd_model.alphas_cumprod + sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) + sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) + x_noisy = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod + + self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1))))) + + c1 = self.cosine_factor ** self.p.cosine_scale_1 + + x_in = x_in*(1 - c1) + x_noisy * c1 + + if self.p.random_jitter: + jitter_range = 
self.jitter_range + else: + jitter_range = 0 + x_in_ = F.pad(x_in,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) + _,_,H,W = x_in.shape + + self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step + self.repeat_3 = False + + x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_,sigma, **kwarg) + self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward + x_out = x_out[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W] + + return x_out + + + @torch.no_grad() + @keep_signature + def sample_one_step(self, x_in, sigma, cond): + assert LatentDiffusion.apply_model + def repeat_func_1(x_tile:Tensor, bboxes,mode=0) -> Tensor: + sigma_tile = self.repeat_tensor(sigma, len(bboxes)) + cond_tile = self.repeat_cond_dict(cond, bboxes,mode) + return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) + + def repeat_func_2(x_tile:Tensor, bboxes,mode=0) -> Tuple[Tensor, Tensor]: + n_rep = len(bboxes) + ts_tile = self.repeat_tensor(sigma, n_rep) + if isinstance(cond, dict): # FIXME: when will enter this branch? + cond_tile = self.repeat_cond_dict(cond, bboxes,mode) + else: + cond_tile = self.repeat_tensor(cond, n_rep) + return self.sampler_forward(x_tile, ts_tile, cond=cond_tile) + + def repeat_func_3(x_tile:Tensor, bboxes,mode=0): + sigma_in_tile = sigma.repeat(len(bboxes)) + cond_out = self.repeat_cond_dict(cond, bboxes,mode) + x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out) + return x_tile_out + + if self.repeat_3: + repeat_func = repeat_func_3 + self.repeat_3 = False + elif self.is_kdiff: + repeat_func = repeat_func_1 + else: + repeat_func = repeat_func_2 + N,_,_,_ = x_in.shape + + + self.x_buffer = torch.zeros_like(x_in) + self.weights = torch.zeros_like(x_in) + + for batch_id, bboxes in enumerate(self.batched_bboxes): + if state.interrupted: return x_in + x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) + x_tile_out = repeat_func(x_tile, bboxes) + # de-batching + for i, bbox in enumerate(bboxes): + self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] + self.weights[bbox.slicer] += 1 + self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode + + x_local = self.x_buffer/self.weights + + self.x_buffer = torch.zeros_like(self.x_buffer) + self.weights = torch.zeros_like(self.weights) + + std_, mean_ = x_in.std(), x_in.mean() + c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2 + if self.p.gaussian_filter: + x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num-1), sigma=self.sig*c3) + x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_ + + if not hasattr(self.p.sd_model, 'apply_model_ori'): + self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model + self.p.sd_model.apply_model = self.apply_model_hijack + x_global = torch.zeros_like(x_local) + jitter_range = self.jitter_range + end = x_global.shape[3]-jitter_range + + current_num = 0 + if self.p.mixture: + for batch_id, bboxes in enumerate(self.global_batched_bboxes): + current_num += len(bboxes) + if current_num > (self.global_num_tiles//2) and (current_num-self.global_tile_bs) < (self.global_num_tiles//2): + res = len(bboxes) - (current_num - self.global_num_tiles//2) + x_in_i = torch.cat([x_in[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] if idx
'
+
+ installstatus = None
+
+ for version in reversed(item['modelVersions']):
+ for file in version.get('files', []):
+ file_name = file['name']
+ file_sha256 = file.get('hashes', {}).get('SHA256', "").upper()
+
+ name_match = file_name in existing_files
+ sha256_match = file_sha256 in existing_files_sha256
+ if name_match or sha256_match:
+ if version == item['modelVersions'][0]:
+ installstatus = "civmodelcardinstalled"
+ else:
+ installstatus = "civmodelcardoutdated"
+ model_name_js = model_name.replace("'", "\\'")
+ model_string = escape(f"{model_name_js} ({model_id})")
+ model_card = f'')
+ if model_version is None:
+ selected_version = item['modelVersions'][0]
+ else:
+ for model in item['modelVersions']:
+ if model['name'] == model_version:
+ selected_version = model
+ break
+
+ if selected_version['trainedWords']:
+ output_training = ",".join(selected_version['trainedWords'])
+ output_training = re.sub(r'<[^>]*:[^>]*>', '', output_training)
+ output_training = re.sub(r', ?', ', ', output_training)
+ output_training = output_training.strip(', ')
+ if selected_version['baseModel']:
+ output_basemodel = selected_version['baseModel']
+ for file in selected_version['files']:
+ dl_dict[file['name']] = file['downloadUrl']
+
+ if not model_filename:
+ model_filename = file['name']
+ dl_url = file['downloadUrl']
+ gl.json_info = item
+ sha256_value = file['hashes'].get('SHA256', 'Unknown')
+
+ size = file['metadata'].get('size', 'Unknown')
+ format = file['metadata'].get('format', 'Unknown')
+ fp = file['metadata'].get('fp', 'Unknown')
+ sizeKB = file.get('sizeKB', 0) * 1024
+ filesize = _download.convert_size(sizeKB)
+
+ unique_file_name = f"{size} {format} {fp} ({filesize})"
+ is_primary = file.get('primary', False)
+ file_list.append(unique_file_name)
+ file_dict.append({
+ "format": format,
+ "sizeKB": sizeKB
+ })
+ if is_primary:
+ default_file = unique_file_name
+ model_filename = file['name']
+ dl_url = file['downloadUrl']
+ gl.json_info = item
+ sha256_value = file['hashes'].get('SHA256', 'Unknown')
+
+ safe_tensor_found = False
+ pickle_tensor_found = False
+ if is_LORA and file_dict:
+ for file_info in file_dict:
+ file_format = file_info.get("format", "")
+ if "SafeTensor" in file_format:
+ safe_tensor_found = True
+ if "PickleTensor" in file_format:
+ pickle_tensor_found = True
+
+ if safe_tensor_found and pickle_tensor_found:
+ if "PickleTensor" in file_dict[0].get("format", ""):
+ if file_dict[0].get("sizeKB", 0) <= 100:
+ model_folder = os.path.join(contenttype_folder("TextualInversion"))
+
+ model_url = selected_version.get('downloadUrl', '')
+ model_main_url = f"https://civitai.com/models/{item['id']}"
+ img_html = ''
+
+ url = f"https://civitai.com/api/v1/model-versions/{selected_version['id']}"
+ api_version = request_civit_api(url)
+
+ for index, pic in enumerate(api_version['images']):
+
+ if from_preview:
+ index = f"preview_{index}"
+
+ class_name = 'class="model-block"'
+ if pic.get('nsfwLevel') >= 4:
+ class_name = 'class="civnsfw model-block"'
+
+ img_html += f'''
+
+
+
+
+
+ '''
+
+ if meta_button:
+ img_html += f'''
+
+
+
+ '''
+ else:
+ img_html += ''
+
+ if prompt_dict:
+ img_html += ''
+ # Define the preferred order of keys
+ preferred_order = ["prompt", "negativePrompt", "seed", "Size", "Model", "Clip skip", "sampler", "steps", "cfgScale"]
+ # Loop through the keys in the preferred order and add them to the HTML
+ for key in preferred_order:
+ if key in prompt_dict:
+ value = prompt_dict[key]
+ key_map = {
+ "prompt": "Prompt",
+ "negativePrompt": "Negative prompt",
+ "seed": "Seed",
+ "Size": "Size",
+ "Model": "Model",
+ "Clip skip": "Clip skip",
+ "sampler": "Sampler",
+ "steps": "Steps",
+ "cfgScale": "CFG scale"
+ }
+ key = key_map.get(key, key)
+
+ if meta_btn:
+ img_html += f''
+ else:
+ img_html += f''
+ # Check if there are remaining keys in meta
+ remaining_keys = [key for key in prompt_dict if key not in preferred_order]
+
+ # Add the rest
+ if remaining_keys:
+ img_html += f"""
+
+
+
+
+
+ """
+ for key in remaining_keys:
+ value = prompt_dict[key]
+ img_html += f''
+ img_html = img_html + ''
+
+ img_html += '
'
+
+ img_html = img_html + ''
+ img_html = img_html + ''\
+ f'{allow_svg if item.get("allowNoCredit") else deny_svg} Use the model without crediting the creator
'\
+ f'{allow_svg if "Image" in allowCommercialUse else deny_svg} Sell images they generate
'\
+ f'{allow_svg if "Rent" in allowCommercialUse else deny_svg} Run on services that generate images for money
'\
+ f'{allow_svg if "RentCivit" in allowCommercialUse else deny_svg} Run on Civitai
'\
+ f'{allow_svg if item.get("allowDerivatives") else deny_svg} Share merges using this model
'\
+ f'{allow_svg if "Sell" in allowCommercialUse else deny_svg} Sell this model or merges using this model
'\
+ f'{allow_svg if item.get("allowDifferentLicense") else deny_svg} Have different permissions when sharing merges'\
+ '
+
+Open the `Danbooru Tags Upsampler` accordion and check the `Enabled` checkbox to enable the extension.
+
+Explanation of parameters:
+
+| Parameter name | Description | Example |
+| -------------- | ----------- | ------------- |
+| **Total tag length** | Specifies the **total number of tags in the prompt after upsampling**, not the number of tags that are added. `very short` means "10 tags or fewer", `short` means "20 tags or fewer", `long` means "40 tags or fewer", and `very long` means more than that. | `long` is recommended |
+| **Ban tags** | Tags listed here will never appear in the upsampled tags, which is useful when there are tags you don't want to see. `*` matches any characters (e.g. `* background` matches `simple background`, `white background`, and so on). | `official alternate costume, english text, * background, ...` |
+| **Seed for upsampling tags** | If this value and the positive prompt are fixed, the upsampled tags are fixed as well. `-1` means a different seed is used every time. | Set to `-1` if you want different upsampling results every time. |
+| **Upsampling timing** | Whether to upsample before or after other prompt processing (such as sd-dynamic-prompts or the webui styles feature) is applied. | `After applying other prompt processing` |
+| **Variety level** | A preset for `Generation config` that controls how varied the upsampled tags are. | `varied` |
+| **Generation config** | The LLM parameters used to generate tags. If you are not familiar with language-model generation parameters, leave this alone and use `Variety level` instead. ||
+
+## Showcases
+
+| Input prompt | Without upsampling | With upsampling |
+| --- | --- | --- |
+| 1girl, solo, cowboy shot (seed: 2396487241) | ![]() | ![]() |
+| (final prompt) | 1girl, solo, cowboy shot | 1girl, solo, cowboy shot, ahoge, animal ears, bare shoulders, blue hair, blush, closed mouth, collarbone, collared shirt, dress, eyelashes, fox ears, fox girl, fox tail, hair between eyes, heart, long hair, long sleeves, looking at viewer, neck ribbon, ribbon, shirt, simple background, sleeves past wrists, smile, tail, white background, white dress, white shirt, yellow eyes |
+| 3girls (seed: 684589178) | ![]() | ![]() |
+| (final prompt) | 3girls | 3girls, black footwear, black hair, black thighhighs, boots, bow, bowtie, chibi, closed mouth, collared shirt, flower, grey hair, hair between eyes, hair flower, hair ornament, long hair, long sleeves, looking at viewer, multiple girls, purple eyes, red eyes, shirt, short hair, sitting, smile, thighhighs, vest, white shirt, white skirt |
+| no humans, scenery (seed: 3702717413) | ![]() | ![]() |
+| (final prompt) | no humans, scenery | no humans, scenery, animal, animal focus, bird, blue eyes, cat, dog, flower, grass, leaf, nature, petals, shadow, sitting, star (sky), sunflower, tree |
+| 1girl, frieren, sousou no frieren (seed: 787304393) | ![]() | ![]() |
+| (final prompt) | 1girl, frieren, sousou no frieren | 1girl, frieren, sousou no frieren, black pantyhose, cape, closed mouth, elf, fingernails, green eyes, grey hair, hair between eyes, long hair, long sleeves, looking at viewer, pantyhose, pointy ears, simple background, skirt, solo, twintails, white background, white skirt |
+
+| Input prompt | Very unvaried | Unvaried | Normal | Varied | Very varied |
+| --- | --- | --- | --- | --- | --- |
+| 1girl, solo, from side | ![]() | ![]() | ![]() | ![]() | ![]() |
+| 1girl, frieren, sousou no frieren, | ![]() | ![]() | ![]() | ![]() | ![]() |
+| no humans, scenery | ![]() | ![]() | ![]() | ![]() | ![]() |
+
+Open the `Danbooru Tags Upsampler` accordion and check the `Enabled` checkbox to enable this extension.
+
+Explanation of parameters:
+
+| Parameter name | Description | Example value |
+| -------------- | ----------- | ------------- |
+| **Total tag length** | Specifies the **total number of tags in the prompt after upsampling**, not the number of tags that get added. `very short` means "less than 10 tags", `short` means "less than 20 tags", `long` means "less than 40 tags", and `very long` means more than that. | `long` is recommended |
+| **Ban tags** | Tags listed in this field will never appear in the upsampled tags, which is useful when you don't want specific tags to show up. `*` matches any characters (e.g. `* background` matches `simple background`, `white background`, ...); see the sketch after this table. | `official alternate costume, english text, * background, ...` |
+| **Seed for upsampling tags** | If this number and the positive prompt are fixed, the upsampled tags are fixed as well. `-1` means a random seed is used every time. | Set to `-1` if you want a different final prompt every time. |
+| **Upsampling timing** | Whether to upsample before or after other prompt processing (e.g. sd-dynamic-prompts or the webui styles feature) is applied. | `After applying other prompt processings` |
+| **Variety level** | A preset for `Generation config` that controls how varied the upsampled tags are. | `varied` |
+| **Generation config** | The LLM parameters used when generating tags. If you are not familiar with language-model generation parameters, it's best to leave this alone and use `Variety level` instead. ||
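+
+The `*` wildcard in **Ban tags** behaves like shell-style globbing. As a rough illustration only (not the extension's actual implementation), a ban list like this could be applied to generated tags with Python's `fnmatch`; the tag list and patterns below are made-up examples.
+
+```python
+from fnmatch import fnmatch
+
+def filter_banned(tags: list[str], ban_patterns: list[str]) -> list[str]:
+    """Drop every tag that matches any ban pattern ('*' matches any characters)."""
+    return [t for t in tags if not any(fnmatch(t, pat) for pat in ban_patterns)]
+
+# Hypothetical example: '* background' removes both background tags below.
+tags = ["1girl", "solo", "simple background", "white background", "english text"]
+print(filter_banned(tags, ["* background", "english text"]))  # ['1girl', 'solo']
+```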
+
+
+
+
+## Showcases
+
+| Input prompt | Without upsampling | With upsampling |
+| --- | --- | --- |
+| 1girl, solo, cowboy shot (seed: 2396487241) | ![]() | ![]() |
+| (prompts used to generate) | 1girl, solo, cowboy shot | 1girl, solo, cowboy shot, ahoge, animal ears, bare shoulders, blue hair, blush, closed mouth, collarbone, collared shirt, dress, eyelashes, fox ears, fox girl, fox tail, hair between eyes, heart, long hair, long sleeves, looking at viewer, neck ribbon, ribbon, shirt, simple background, sleeves past wrists, smile, tail, white background, white dress, white shirt, yellow eyes |
+| 3girls (seed: 684589178) | ![]() | ![]() |
+| (prompts used to generate) | 3girls | 3girls, black footwear, black hair, black thighhighs, boots, bow, bowtie, chibi, closed mouth, collared shirt, flower, grey hair, hair between eyes, hair flower, hair ornament, long hair, long sleeves, looking at viewer, multiple girls, purple eyes, red eyes, shirt, short hair, sitting, smile, thighhighs, vest, white shirt, white skirt |
+| no humans, scenery (seed: 3702717413) | ![]() | ![]() |
+| (prompts used to generate) | no humans, scenery | no humans, scenery, animal, animal focus, bird, blue eyes, cat, dog, flower, grass, leaf, nature, petals, shadow, sitting, star (sky), sunflower, tree |
+| 1girl, frieren, sousou no frieren (seed: 787304393) | ![]() | ![]() |
+| (prompts used to generate) | 1girl, frieren, sousou no frieren | 1girl, frieren, sousou no frieren, black pantyhose, cape, closed mouth, elf, fingernails, green eyes, grey hair, hair between eyes, long hair, long sleeves, looking at viewer, pantyhose, pointy ears, simple background, skirt, solo, twintails, white background, white skirt |
+
+| Input prompt | Very unvaried | Unvaried | Normal | Varied | Very varied |
+| --- | --- | --- | --- | --- | --- |
+| 1girl, solo, from side | ![]() | ![]() | ![]() | ![]() | ![]() |
+| 1girl, frieren, sousou no frieren, | ![]() | ![]() | ![]() | ![]() | ![]() |
+| no humans, scenery | ![]() | ![]() | ![]() | ![]() | ![]() |
[Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+and [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge). It uses the [ExifReader](https://github.com/mattiasw/ExifReader) library to extract image metadata locally.
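+
+ExifReader is a JavaScript library, so the parsing happens in the browser. As an illustration of the same idea (reading generation metadata straight from the image file, without sending it anywhere), here is a minimal Python sketch using Pillow; it is an assumption for illustration only, not part of this extension, and `example.png` is a hypothetical file name.
+
+```python
+from PIL import Image
+
+def read_generation_metadata(path: str) -> dict:
+    """Collect text metadata embedded in an image, e.g. the 'parameters'
+    chunk that Stable Diffusion WebUI writes into its PNG outputs."""
+    meta = {}
+    with Image.open(path) as im:
+        if "parameters" in im.info:       # PNG tEXt/iTXt entries land in im.info
+            meta["parameters"] = im.info["parameters"]
+        exif = im.getexif()
+        if exif and exif.get(0x9286):     # 0x9286 = EXIF UserComment, used by some tools
+            meta["UserComment"] = exif.get(0x9286)
+    return meta
+
+print(read_generation_metadata("example.png"))  # hypothetical file
+```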
+
+
+
${inputs}