| | import modules.scripts as scripts |
| | import gradio as gr |
| |
|
| | import io |
| | import json |
| | import matplotlib.pyplot as plt |
| | from PIL import Image |
| | import numpy as np |
| | import inspect |
| | import torch |
| | from modules import prompt_parser, devices, sd_samplers_common |
| | import re |
| | from modules.shared import opts, state |
| | import modules.shared as shared |
| | from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback |
| | from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback |
| | from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback |
| |
|
| | from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser |
| | from modules.sd_samplers_timesteps import CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser |
| | from modules.sd_samplers_cfg_denoiser import catenate_conds, subscript_cond, pad_cond |
| | from modules import script_callbacks |
| |
|
| | import k_diffusion.utils as utils_old |
| |
|
| | try: |
| | from modules_forge import forge_sampler |
| | from modules_forge.forge_sampler import * |
| | from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn |
| | from ldm_patched.modules import model_management |
| | from ldm_patched.modules.ops import cleanup_cache |
| | from ldm_patched.modules.samplers import * |
| | isForge = True |
| | except Exception: |
| | isForge = False |
| |
|
| | from scripts.CharaIte import Chara_iteration |
| |
|
| | |
| | |
| |
|
| |
|
# Import-time marker: this top-level print runs when the module is loaded.
print("**********Read forge sample code *********")
| |
|
def calc_cond_uncond_batch(self, model, cond, uncond, x_in, timestep, model_options, cond_scale):
    """Run the model over all cond/uncond entries and return weighted averages.

    Every conditioning entry is converted into a work item with
    ``get_area_and_mult``.  Items that ``can_concat_cond`` reports as
    compatible are batched together, with the batch size limited by the free
    memory reported by ``model_management.get_free_memory``.  Each batch is
    evaluated through ``model_options['model_function_wrapper']`` when that key
    is present, otherwise through ``Chara_iteration``.  Outputs are accumulated
    into per-area sums and finally divided by the accumulated multipliers.

    The helpers ``get_area_and_mult``, ``can_concat_cond``, ``cond_cat``,
    ``compute_cond_mark`` and ``compute_cond_indices`` are not defined in this
    file; presumably they arrive through the Forge star imports.

    Args:
        self: Forwarded unchanged as the first argument of ``Chara_iteration``.
        model: Object providing ``memory_required`` and ``apply_model``.
        cond: Iterable of conditioning entries.
        uncond: Iterable of unconditional entries, or ``None`` to skip them.
        x_in: Input tensor; the output accumulators share its shape.
        timestep: Tensor repeated once per batched item before evaluation.
        model_options: Mapping; the keys ``'transformer_options'`` and
            ``'model_function_wrapper'`` are read here.
        cond_scale: Forwarded unchanged to ``Chara_iteration``.

    Returns:
        A tuple ``(out_cond, out_uncond)`` of tensors shaped like ``x_in``.
    """
    out_cond = torch.zeros_like(x_in)
    # The tiny initial count keeps the final division finite where nothing
    # was accumulated.
    out_count = torch.ones_like(x_in) * 1e-37

    out_uncond = torch.zeros_like(x_in)
    out_uncond_count = torch.ones_like(x_in) * 1e-37

    COND = 0
    UNCOND = 1

    to_run = []
    for entry in cond:
        item = get_area_and_mult(entry, x_in, timestep)
        if item is None:
            continue
        to_run.append((item, COND))
    if uncond is not None:
        for entry in uncond:
            item = get_area_and_mult(entry, x_in, timestep)
            if item is None:
                continue
            to_run.append((item, UNCOND))

    while to_run:
        first = to_run[0]
        first_shape = first[0][0].shape
        to_batch_temp = [
            idx for idx, candidate in enumerate(to_run)
            if can_concat_cond(candidate[0], first[0])
        ]

        # Reversed so that popping by index below never shifts an index that
        # is still waiting to be popped.
        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]

        free_memory = model_management.get_free_memory(x_in.device)
        # Try the whole group first, then progressively smaller fractions,
        # until the model reports that the batch fits into free memory.
        for divisor in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp) // divisor]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) < free_memory:
                to_batch = batch_amount
                break

        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for idx in to_batch:
            popped = to_run.pop(idx)
            item = popped[0]

            input_x.append(item.input_x)
            mult.append(item.mult)
            c.append(item.conditioning)
            area.append(item.area)
            cond_or_uncond.append(popped[1])
            # The last popped item decides control/patches for the batch.
            control = item.control
            patches = item.patches

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()

        if patches is not None:
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for key in patches:
                    if key in cur_patches:
                        cur_patches[key] = cur_patches[key] + patches[key]
                    else:
                        cur_patches[key] = patches[key]
                # Fix: store the merged mapping back; previously the merge
                # result was discarded and the per-item patches were lost.
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep

        transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
        transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)

        c['transformer_options'] = transformer_options

        if control is not None:
            print('control is running')
            # Hand the same options to every node of the linked control chain.
            node = control
            while node is not None:
                node.transformer_options = transformer_options
                node = node.previous_controlnet
            # A copy is passed so get_control receives its own dict.
            control_cond = c.copy()
            c['control'] = control.get_control(input_x, timestep_, control_cond, len(cond_or_uncond))
            c['control_model'] = control

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            # NOTE(review): uncond may be None here (the caller passes None
            # when CFG is skipped), in which case uncond[0] raises TypeError.
            # What Chara_iteration expects in that case is not visible in this
            # file -- confirm before guarding.
            output = Chara_iteration(self, model, None, input_x, timestep_, cond_scale, uncond[0]['cross_attn'], c).chunk(batch_chunks)

        del input_x

        for idx in range(batch_chunks):
            a = area[idx]
            # Same window as [:, :, a[2]:a[0] + a[2], a[3]:a[1] + a[3]].
            region = (
                slice(None),
                slice(None),
                slice(a[2], a[0] + a[2]),
                slice(a[3], a[1] + a[3]),
            )
            if cond_or_uncond[idx] == COND:
                out_cond[region] += output[idx] * mult[idx]
                out_count[region] += mult[idx]
            else:
                out_uncond[region] += output[idx] * mult[idx]
                out_uncond_count[region] += mult[idx]
        del mult

    out_cond /= out_count
    del out_count
    out_uncond /= out_uncond_count
    del out_uncond_count
    return out_cond, out_uncond
def sampling_function(self, model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
    """Compute the guided prediction for one sampling step.

    Runs the optional ``sampler_pre_cfg_function`` hooks, evaluates the cond
    and uncond predictions with ``calc_cond_uncond_batch``, combines them
    (through ``sampler_cfg_function`` when configured, otherwise with the
    usual ``uncond + (cond - uncond) * scale`` formula) and finally applies
    the optional ``sampler_post_cfg_function`` hooks.

    Args:
        self: Forwarded unchanged to ``calc_cond_uncond_batch``.
        model: Model object handed to the hooks and the batch evaluation.
        x: Input tensor for this step.
        timestep: Timestep/sigma value for this step.
        uncond: Unconditional entries.  Not evaluated when ``cond_scale`` is
            close to 1.0 and ``disable_cfg1_optimization`` is not set.
        cond: Conditioning entries; each may carry a ``'strength'`` key
            (missing means 1).
        cond_scale: Guidance scale.
        model_options: Optional mapping of sampler hooks and settings.
            ``None`` is treated as an empty mapping.
        seed: Accepted for interface compatibility; not used here.

    Returns:
        The guided prediction after all post-CFG hooks have run.
    """
    # Function-scope import: the module never imports math itself, so the name
    # would otherwise only resolve if a star import happened to provide it.
    import math

    # A None sentinel avoids sharing one mutable default dict across calls.
    if model_options is None:
        model_options = {}

    edit_strength = sum(item.get('strength', 1) for item in cond)

    # The explicit comparison with False is kept on purpose so that unusual
    # option values are treated exactly as before.
    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:  # noqa: E712
        uncond_ = None
    else:
        uncond_ = uncond

    for fn in model_options.get("sampler_pre_cfg_function", []):
        model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options)

    cond_pred, uncond_pred = calc_cond_uncond_batch(self, model, cond, uncond_, x, timestep, model_options, cond_scale)

    if "sampler_cfg_function" in model_options:
        args = {
            "cond": x - cond_pred,
            "uncond": x - uncond_pred,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "cond_denoised": cond_pred,
            "uncond_denoised": uncond_pred,
            "model": model,
            "model_options": model_options,
        }
        cfg_result = x - model_options["sampler_cfg_function"](args)
    elif not math.isclose(edit_strength, 1.0):
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale * edit_strength
    else:
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    for fn in model_options.get("sampler_post_cfg_function", []):
        # Post hooks receive the original uncond, not the possibly-None uncond_.
        args = {
            "denoised": cfg_result,
            "cond": cond,
            "uncond": uncond,
            "model": model,
            "uncond_denoised": uncond_pred,
            "cond_denoised": cond_pred,
            "sigma": timestep,
            "model_options": model_options,
            "input": x,
        }
        cfg_result = fn(args)
    print("**********CHG Sampling***********")
    return cfg_result
| |
|
| | |
| | |
| |
|
| |
|
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    """Prepare Forge conditioning from denoiser params and run one sampling step.

    Reads the model, linked control list, extra concat condition and model
    options from ``self.inner_model.inner_model.forge_objects.unet``, converts
    the text cond/uncond into Forge's conditioning format, attaches the image
    condition (when its batch and spatial sizes match ``x``) and the control
    list, applies any ``conditioning_modifiers`` from the model options and
    finally delegates to ``sampling_function``.

    Args:
        self: Denoiser instance; ``self.inner_model.inner_model.forge_objects``
            and ``self.p.seeds`` are read.
        denoiser_params: Object exposing ``x``, ``sigma``, ``text_cond``,
            ``text_uncond`` and ``image_cond``.
        cond_scale: Guidance scale passed on to ``sampling_function``.
        cond_composition: Weighting information passed to
            ``cond_from_a1111_to_patched_ldm_weighted``.

    Returns:
        Whatever ``sampling_function`` returns for this step.
    """
    unet = self.inner_model.inner_model.forge_objects.unet
    model = unet.model
    control = unet.controlnet_linked_list
    extra_concat_condition = unet.extra_concat_condition

    x = denoiser_params.x
    timestep = denoiser_params.sigma

    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)

    model_options = unet.model_options
    seed = self.p.seeds[0]

    # An explicitly configured extra concat condition wins over the
    # per-call image condition.
    image_cond_in = (
        extra_concat_condition
        if extra_concat_condition is not None
        else denoiser_params.image_cond
    )

    shapes_match = (
        isinstance(image_cond_in, torch.Tensor)
        and image_cond_in.shape[0] == x.shape[0]
        and image_cond_in.shape[2] == x.shape[2]
        and image_cond_in.shape[3] == x.shape[3]
    )
    if shapes_match:
        # Uncond entries first, then cond entries; each gets its own wrapper.
        for entry in [*uncond, *cond]:
            entry['model_conds']['c_concat'] = CONDRegular(image_cond_in)

    if control is not None:
        for entry in cond + uncond:
            entry['control'] = control

    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(
            model, x, timestep, uncond, cond, cond_scale, model_options, seed
        )

    return sampling_function(self, model, x, timestep, uncond, cond, cond_scale, model_options, seed)
| |
|
| |
|
def CHGdenoiserConstruct():
    """Return the Python source code of the ``CHGDenoiser`` class as a string.

    The returned text defines ``CHGDenoiser`` as a subclass of ``CFGDenoiser``
    whose ``forward`` delegates the actual sampling to
    ``forge_CHG.forge_sample``.  The source refers to names that are not bound
    by this function (``CFGDenoiser``, ``state``, ``sd_samplers_common``,
    ``prompt_parser``, ``torch``, ``CFGDenoiserParams``,
    ``cfg_denoiser_callback``, ``AfterCFGCallbackParams``,
    ``cfg_after_cfg_callback`` and ``forge_CHG``); presumably the caller
    executes the string in a namespace that provides them -- confirm against
    the call site.

    Compared with the previous text, the sigma lookup index is clipped to
    ``fake_sigmas.shape[0] - 1`` (the last valid index) instead of
    ``fake_sigmas.shape[0]``, which is one past the end of the table.

    Returns:
        str: Source code defining the ``CHGDenoiser`` class.
    """
    CHGDenoiserStr = '''
class CHGDenoiser(CFGDenoiser):
    def __init__(self, sampler):
        super().__init__(sampler)
    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        #self.inner_model: CompVisDenoiser
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        original_x_device = x.device
        original_x_dtype = x.dtype
        acd = self.inner_model.inner_model.alphas_cumprod

        if self.classic_ddim_eps_estimation:
            acd = self.inner_model.inner_model.alphas_cumprod
            fake_sigmas = ((1 - acd) / acd) ** 0.5
            real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]) - 1)]
            real_sigma_data = 1.0
            x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
            sigma = real_sigma

        if sd_samplers_common.apply_refiner(self, x):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        if self.mask is not None:
            noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)
            x = x * self.nmask + noisy_initial_latent * self.mask

        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
        cfg_denoiser_callback(denoiser_params)

        denoised = forge_CHG.forge_sample(self,denoiser_params=denoiser_params,
                                          cond_scale=cond_scale, cond_composition=cond_composition)

        if self.mask is not None:
            denoised = denoised * self.nmask + self.init_latent * self.mask

        preview = self.sampler.last_latent = denoised
        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1

        if self.classic_ddim_eps_estimation:
            eps = (x - denoised) / sigma[:, None, None, None]
            return eps
        print("*****CHG ini success*****")

        return denoised.to(device=original_x_device, dtype=original_x_dtype)
'''

    return CHGDenoiserStr