import torch

from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
from ldm_patched.modules.samplers import sampling_function
from ldm_patched.modules import model_management
from ldm_patched.modules.ops import cleanup_cache


def cond_from_a1111_to_patched_ldm(cond):
    # A1111 passes SD1/SD2 conditioning as a raw cross-attention tensor and
    # SDXL conditioning as a dict with 'crossattn' and pooled 'vector' entries.
    # Both are converted into the patched-ldm list-of-dicts cond format.
    if isinstance(cond, torch.Tensor):
        result = dict(
            cross_attn=cond,
            model_conds=dict(
                c_crossattn=CONDCrossAttn(cond),
            )
        )
        return [result, ]

    cross_attn = cond['crossattn']
    pooled_output = cond['vector']

    result = dict(
        cross_attn=cross_attn,
        pooled_output=pooled_output,
        model_conds=dict(
            c_crossattn=CONDCrossAttn(cross_attn),
            y=CONDRegular(pooled_output)
        )
    )

    return [result, ]


def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
    # 'weights' is A1111's prompt composition: one list per batch item holding
    # (cond_index, weight) pairs. Transposing groups the k-th pair of every
    # batch item together so each group becomes one weighted cond entry.
    transposed = list(map(list, zip(*weights)))
    results = []

    for cond_pre in transposed:
        current_indices = []
        current_weight = 0
        for i, w in cond_pre:
            current_indices.append(i)
            current_weight = w

        if hasattr(cond, 'advanced_indexing'):
            feed = cond.advanced_indexing(current_indices)
        else:
            feed = cond[current_indices]

        h = cond_from_a1111_to_patched_ldm(feed)
        h[0]['strength'] = current_weight
        results += h

    return results


def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    # Bridges A1111's CFGDenoiser to the patched-ldm sampling_function:
    # converts the conditioning, attaches image conditioning and ControlNets,
    # then runs the classifier-free-guidance sampling step.
    model = self.inner_model.inner_model.forge_objects.unet.model
    control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
    extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
    model_options = self.inner_model.inner_model.forge_objects.unet.model_options
    seed = self.p.seeds[0]

    if extra_concat_condition is not None:
        image_cond_in = extra_concat_condition
    else:
        image_cond_in = denoiser_params.image_cond

    # Only attach c_concat when the image conditioning matches the latent's
    # batch size and spatial resolution (e.g. inpainting models).
    if isinstance(image_cond_in, torch.Tensor):
        if image_cond_in.shape[0] == x.shape[0] \
                and image_cond_in.shape[2] == x.shape[2] \
                and image_cond_in.shape[3] == x.shape[3]:
            for i in range(len(uncond)):
                uncond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
            for i in range(len(cond)):
                cond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)

    if control is not None:
        for h in cond + uncond:
            h['control'] = control

    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
    return denoised
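
# --- Hedged example, not part of the original module ---
# forge_sample() above lets every callback registered under
# model_options['conditioning_modifiers'] rewrite the full sampling state
# before sampling_function() runs. The sketch below only illustrates the
# expected signature and return order, which follow directly from the loop
# in forge_sample(); the clamping logic itself is purely illustrative.
def example_conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
    # Illustrative tweak: cap the CFG scale at 10.
    cond_scale = min(cond_scale, 10.0)
    # Modifiers must return all eight values in the same order they receive them.
    return model, x, timestep, uncond, cond, cond_scale, model_options, seed

# Registration is assumed to happen by appending to the UNet patcher's
# model_options, which is the dict forge_sample() reads from, e.g.:
#     unet.model_options.setdefault('conditioning_modifiers', []).append(example_conditioning_modifier)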

def sampling_prepare(unet, x):
    # Estimate peak UNet memory for the doubled (cond + uncond) batch, add what
    # any linked ControlNets need, then load everything onto the GPU up front.
    B, C, H, W = x.shape

    memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required)

    unet_inference_memory = memory_estimation_function([B * 2, C, H, W])

    additional_inference_memory = unet.extra_preserved_memory_during_sampling
    additional_model_patchers = unet.extra_model_patchers_during_sampling

    if unet.controlnet_linked_list is not None:
        additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
        additional_model_patchers += unet.controlnet_linked_list.get_models()

    model_management.load_models_gpu(
        models=[unet] + additional_model_patchers,
        memory_required=unet_inference_memory + additional_inference_memory)

    real_model = unet.model

    percent_to_timestep_function = lambda p: real_model.model_sampling.percent_to_sigma(p)

    for cnet in unet.list_controlnets():
        cnet.pre_run(real_model, percent_to_timestep_function)

    return


def sampling_cleanup(unet):
    # Release ControlNet state and cached weights after sampling finishes.
    for cnet in unet.list_controlnets():
        cnet.cleanup()
    cleanup_cache()
    return
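
# --- Hedged usage sketch, not part of the original module ---
# A minimal illustration of the cond conversion helpers defined above. The
# tensor shapes, the SDXL-style dict keys beyond what the code reads
# ('crossattn', 'vector'), and the composition layout (one list of
# (cond_index, weight) pairs per batch item) are illustrative assumptions.
if __name__ == '__main__':
    # SD1-style conditioning: a plain cross-attention tensor.
    sd_cond = torch.randn(1, 77, 768)
    print(cond_from_a1111_to_patched_ldm(sd_cond)[0]['model_conds'].keys())
    # -> dict_keys(['c_crossattn'])

    # SDXL-style conditioning: token embeddings plus a pooled vector.
    sdxl_cond = dict(crossattn=torch.randn(1, 77, 2048), vector=torch.randn(1, 2816))
    print(cond_from_a1111_to_patched_ldm(sdxl_cond)[0]['model_conds'].keys())
    # -> dict_keys(['c_crossattn', 'y'])

    # Weighted conversion for a single-image batch with two composed prompts.
    two_prompts = torch.randn(2, 77, 768)
    composition = [[(0, 1.0), (1, 0.5)]]
    weighted = cond_from_a1111_to_patched_ldm_weighted(two_prompts, composition)
    print([c['strength'] for c in weighted])
    # -> [1.0, 0.5]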