from dataclasses import dataclass
from typing import Union, List, Optional

import torch
import numpy as np
import PIL.Image
from tqdm.auto import trange
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# The wildcard import also provides StableDiffusionXLImg2ImgPipeline,
# retrieve_timesteps, rescale_noise_cfg and randn_tensor used below.
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import *
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerDiscreteScheduler
from diffusers.utils.outputs import BaseOutput

from modules.layerdiffuse.vae import TransparentVAEDecoder, TransparentVAEEncoder, vae_encode
from .layerdiff3d import UNetFrameConditionModel
from utils.torch_utils import seed_everything, img2tensor, tensor2img
@dataclass
class LayerdiffPipelineOutput(BaseOutput):
    """
    Output class for the Layerdiffuse pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        vis_list (`List[PIL.Image.Image]` or `np.ndarray`):
            Per-layer visualization images produced by the transparent VAE decoder.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    vis_list: Union[List[PIL.Image.Image], np.ndarray]
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, show_progress=True, c_concat=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        model_input = x
        denoised = model(model_input, sigmas[i] * s_in, c_concat=c_concat, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (or final step to sigma=0): first-order DPM-Solver++ update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Second-order multistep update using the previous denoised estimate.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
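
# Usage sketch for `sample_dpmpp_2m` (hypothetical; `denoiser`, `get_sigmas` and
# `cond_latent` are placeholders, not part of this module). Assumes a k-diffusion
# style wrapper mapping (x, sigma, c_concat) -> predicted clean sample:
#
#   sigmas = get_sigmas(num_steps)                      # descending, ending at 0
#   x = torch.randn(1, 4, 128, 128) * sigmas[0]
#   x0 = sample_dpmpp_2m(denoiser, x, sigmas, c_concat=cond_latent)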
class KDiffusionStableDiffusionXLPipeline(StableDiffusionXLImg2ImgPipeline):
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "image_encoder",
        "feature_extractor",
    ]
    def __init__(self,
                 vae,
                 text_encoder,
                 tokenizer,
                 text_encoder_2,
                 tokenizer_2,
                 unet,
                 scheduler=None,
                 trans_vae=None,
                 tag_list=None,
                 image_encoder: CLIPVisionModelWithProjection = None,
                 feature_extractor: CLIPImageProcessor = None,
                 requires_aesthetics_score: bool = False,
                 force_zeros_for_empty_prompt: bool = True,
                 add_watermarker: Optional[bool] = None,
                 ):
        if scheduler is None:
            # Build a default scheduler. The table maps sampler names to
            # (scheduler class, config overrides) for several final-sigma conventions.
            config_min = {"final_sigmas_type": "sigma_min"}
            config_min_euler = {"final_sigmas_type": "sigma_min", "euler_at_final": True}
            config_zero = {"final_sigmas_type": "zero"}
            schedulers = {
                "DPMPP_2M": {
                    "min": (DPMSolverMultistepScheduler, config_min),
                    "min_euler": (DPMSolverMultistepScheduler, config_min_euler),
                    "zero": (DPMSolverMultistepScheduler, config_zero),
                },
                "DPMPP_2M_K": {
                    "min": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min}),
                    "min_euler": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
                    "zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_zero}),
                },
                "DPMPP_2M_SDE": {
                    "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min}),
                    "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min_euler}),
                    "zero": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_zero}),
                },
                "DPMPP_2M_SDE_K": {
                    "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min}),
                    "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min_euler}),
                    "zero": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_zero}),
                },
                "DPMPP": {
                    "min": (DPMSolverSinglestepScheduler, config_min),
                    "min_euler": (DPMSolverSinglestepScheduler, config_min_euler),
                    "zero": (DPMSolverSinglestepScheduler, config_zero),
                },
                "DPMPP_K": {
                    "min": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min}),
                    "min_euler": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
                    "zero": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_zero}),
                },
            }
            # Default: DPM-Solver++(2M) SDE with the sigma schedule ending at zero.
            model_id = "frankjoshua/juggernautXL_version6Rundiffusion"
            scheduler_name = "DPMPP_2M_SDE"
            scheduler_config_name = "zero"
            scheduler_cls, scheduler_config = schedulers[scheduler_name][scheduler_config_name]
            scheduler = scheduler_cls.from_pretrained(
                model_id,
                subfolder="scheduler",
                **scheduler_config,
            )
        super().__init__(
            vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2,
            unet=unet, scheduler=scheduler, feature_extractor=feature_extractor, image_encoder=image_encoder,
            requires_aesthetics_score=requires_aesthetics_score,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, add_watermarker=add_watermarker)
        # self.register_to_config(tag_list=tag_list)
        self.register_modules(trans_vae=trans_vae)
    @property
    def do_classifier_free_guidance(self):
        # Must be a property: it is used below as `if self.do_classifier_free_guidance`,
        # which would otherwise always be truthy (a bound method).
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
    def encode_cropped_prompt_77tokens(self, prompt: str):
        device = self.text_encoder.device
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]
        pooled_prompt_embeds = None
        prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_input_ids = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True, return_dict=False)

            # We are only interested in the pooled output of the final text encoder.
            pooled_prompt_embeds = prompt_embeds[0]
            # Use the penultimate hidden state, as is standard for SDXL conditioning.
            prompt_embeds = prompt_embeds[-1][-2]

            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(dtype=self.unet.dtype, device=device)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        return prompt_embeds, pooled_prompt_embeds
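
    # Shape sketch for a single prompt with the standard SDXL text encoders
    # (informational only; not enforced here):
    #   prompt_embeds:        (1, 77, 2048)  -- 768-d (encoder 1) + 1280-d (encoder 2)
    #   pooled_prompt_embeds: (1, 1280)      -- pooled output of text_encoder_2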
    def denoise_func(self, latents, add_text_embeds, add_time_ids, prompt_embeds, c_concat, num_inference_steps=50):
        """Run the scheduler denoising loop on `latents`, conditioned on the concatenated latent `c_concat`.

        Embeddings are expected to be pre-stacked as [uncond, cond] when classifier-free guidance is enabled.
        """
        # Prepare timesteps
        device = self.unet.device
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps=None, sigmas=None
        )
        latents = latents * self.scheduler.init_noise_sigma
        for i, t in enumerate(timesteps):
            # Expand the latents (and the condition latent) if we are doing classifier-free guidance.
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            c_concat_input = torch.cat([c_concat] * 2) if self.do_classifier_free_guidance else c_concat

            # Predict the noise residual; the condition latent is concatenated along channels.
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            noise_pred = self.unet(
                torch.cat([latent_model_input, c_concat_input], dim=-3),
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # Perform guidance.
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                if self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://huggingface.co/papers/2305.08891
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                # Some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug:
                # https://github.com/pytorch/pytorch/pull/99272
                latents = latents.to(latents_dtype)
        return latents
    @torch.no_grad()
    def __call__(
            self,
            initial_latent: Optional[torch.FloatTensor] = None,
            strength: float = 1.0,
            num_inference_steps: int = 25,
            guidance_scale: float = 5.0,
            batch_size: Optional[int] = 1,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            c_concat=None,
            prompt=None,
            negative_prompt=None,
            show_progress=True,
            fullpage=None,
            group_index=None,
    ):
        device = self.unet.device
        dtype = self.unet.dtype
        page_alpha = None
        if fullpage is not None:
            # Split the RGBA page into an alpha mask and an RGB condition image.
            page_alpha = img2tensor(fullpage[..., -1] / 255., device=self.vae.device, dtype=self.vae.dtype)[0][..., None]
            fullpage = fullpage[..., :3]
            c_concat = np.concatenate([np.full_like(fullpage[..., :1], fill_value=255), fullpage], axis=2)
            c_concat = img2tensor(c_concat, normalize=True)
            c_concat = vae_encode(self.vae, self.trans_vae.encoder, c_concat, use_offset=False).to(device=device, dtype=dtype)
        assert c_concat is not None, "Either `fullpage` or `c_concat` must be provided."
        c_concat = c_concat.to(dtype=dtype)
        self._guidance_scale = guidance_scale
        self._guidance_rescale = 0.0  # required by the base class's `guidance_rescale` property
        is_3d = isinstance(self.unet, UNetFrameConditionModel)
        lh, lw = c_concat.shape[-2:]
        num_frames = 1
        if is_3d:
            if prompt is not None:
                num_frames = len(prompt)
            if prompt_embeds is not None:
                num_frames = len(prompt_embeds)
        if initial_latent is None:
            initial_latent = torch.zeros((batch_size, 4, lh, lw), device=self.unet.device, dtype=self.unet.dtype)
        if is_3d and c_concat.ndim == 4:
            c_concat = c_concat[:, None].expand(-1, num_frames, -1, -1, -1)
        if is_3d and initial_latent.ndim == 4:
            initial_latent = initial_latent[:, None].expand(-1, num_frames, -1, -1, -1)
        if prompt is not None:
            prompt_embeds, pooled_prompt_embeds = self.encode_cropped_prompt_77tokens(prompt)
        if negative_prompt is not None and self.do_classifier_free_guidance:
            negative_prompt_embeds, negative_pooled_prompt_embeds = self.encode_cropped_prompt_77tokens(negative_prompt)
        # Initial latents: in the 3D (frame-conditioned) case, share one noise sample
        # across frames so every layer starts from the same noise.
        if initial_latent.ndim == 5:
            noise = randn_tensor(initial_latent[:, [0]].shape, generator=generator, device=device, dtype=self.unet.dtype).expand(-1, num_frames, -1, -1, -1)
        else:
            noise = randn_tensor(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype)

        # SDXL micro-conditioning: (original size, crop top-left, target size).
        height = lh * self.vae_scale_factor
        width = lw * self.vae_scale_factor
        add_time_ids = list((height, width) + (0, 0) + (height, width))
        add_time_ids = torch.tensor([add_time_ids], dtype=self.unet.dtype)
        add_time_ids = add_time_ids.expand((prompt_embeds.shape[0], -1))
        add_neg_time_ids = add_time_ids.clone()

        # Batch
        add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
        add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
        prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device)
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(device)
        # Sampler kwargs for the k-diffusion path (`sample_dpmpp_2m`); the scheduler
        # loop below does not consume these.
        sampler_kwargs = dict(
            cfg_scale=guidance_scale,
            positive=dict(
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            ),
        )
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device)
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(device)
            sampler_kwargs['negative'] = dict(
                encoder_hidden_states=negative_prompt_embeds,
                added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
            )
        if self.do_classifier_free_guidance:
            # Stack conditioning as [uncond, cond] to match the doubled latents below
            # (standard diffusers CFG convention).
            assert negative_prompt_embeds is not None, "classifier-free guidance requires negative embeddings"
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps=None, sigmas=None
        )
        latents = noise * self.scheduler.init_noise_sigma
        for i, t in enumerate(timesteps):
            # Expand the latents (and the condition latent) if we are doing classifier-free guidance.
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            c_concat_input = torch.cat([c_concat] * 2) if self.do_classifier_free_guidance else c_concat

            # Predict the noise residual.
            added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
            noise_pred = self.unet(
                torch.cat([latent_model_input, c_concat_input], dim=-3),
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
                group_index=group_index,
            )[0]

            # Perform guidance.
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                if self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://huggingface.co/papers/2305.08891
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                # Some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug:
                # https://github.com/pytorch/pytorch/pull/99272
                latents = latents.to(latents_dtype)
        if latents.ndim == 5:
            latents = latents[0]
        if self.trans_vae is None:
            return latents

        # Decode each layer latent through the transparent VAE decoder.
        latents = latents.to(dtype=self.trans_vae.dtype, device=self.trans_vae.device) / self.vae.config.scaling_factor
        vis_list = []
        res_list = []
        for latent in latents:
            latent = latent[None]
            result_list, vis_list_batch = self.trans_vae.decoder(self.vae, latent, mask=page_alpha)
            vis_list += vis_list_batch
            res_list += result_list
        return LayerdiffPipelineOutput(images=res_list, vis_list=vis_list)
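
# Minimal usage sketch (hypothetical; component loading and the RGBA page source
# depend on the surrounding project):
#
#   pipe = KDiffusionStableDiffusionXLPipeline(
#       vae=vae, text_encoder=te1, tokenizer=tok1, text_encoder_2=te2, tokenizer_2=tok2,
#       unet=unet, trans_vae=trans_vae,
#   )
#   out = pipe(prompt="a cat sticker", negative_prompt="blurry",
#              fullpage=page_rgba_uint8, num_inference_steps=25, guidance_scale=5.0)
#   out.images[0].save("layer_0.png")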