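"""Sliced-Wasserstein (SW) guided Stable Diffusion 3 pipeline.

Extends ``StableDiffusion3Pipeline`` with a per-step latent optimization that
matches the Lab color distribution of the denoised prediction to a reference
image via ``VectorSWDLoss``. See ``run`` at the bottom of this file for the
high-level entry point.
"""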
import gc
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import numpy as np
import PIL.Image
import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    StableDiffusion3Pipeline,
)
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
    XLA_AVAILABLE,
    StableDiffusion3PipelineOutput,
    calculate_shift,
    retrieve_timesteps,
)

from src.loss.vector_swd import VectorSWDLoss
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, write_img

if XLA_AVAILABLE:
    from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import xm

def _no_grad_noise(model, *args, **kw):
    """Forward pass with gradients disabled; the result is returned detached."""
    with torch.no_grad():
        return model(*args, **kw, return_dict=False)[0].detach()

class SWStableDiffusion3Pipeline(StableDiffusion3Pipeline):
    swd: Optional[VectorSWDLoss] = None

    def setup_swd(
        self,
        num_projections: int = 64,
        use_ucv: bool = False,
        use_lcv: bool = False,
        distance: Literal["l1", "l2"] = "l1",
        num_new_candidates: int = 32,
        subsampling_factor: int = 1,
        sampling_mode: Literal["gaussian", "qmc"] = "qmc",
    ):
        """Initialize the sliced-Wasserstein loss used for color guidance."""
        self.swd = VectorSWDLoss(
            num_proj=num_projections,
            distance=distance,
            use_ucv=use_ucv,
            use_lcv=use_lcv,
            num_new_candidates=num_new_candidates,
            missing_value_method="interpolate",
            ess_alpha=-1,
            sampling_mode=sampling_mode,
        ).to(self.device)
        self.subsampling_factor = subsampling_factor

    def do_sw_guidance(
        self,
        sw_steps,
        sw_u_lr,
        latents,
        t,
        prompt_embeds,
        pooled_prompt_embeds,
        pixels_ref,
        cur_iter_step,
        write_video_animation_path: Optional[str] = None,
    ):
        """Optimize an additive latent offset ``u`` so that the Lab color
        distribution of the predicted clean image matches the reference."""
        if sw_steps == 0:
            return latents

        # When CFG doubled the embedding batch, keep only the conditional half.
        if latents.shape[0] != prompt_embeds.shape[0]:
            prompt_embeds = prompt_embeds[1].unsqueeze(0)
            pooled_prompt_embeds = pooled_prompt_embeds[1].unsqueeze(0)

        timestep = t.expand(latents.shape[0])

        # Reference pixels arrive as (3, H, W) in [0, 1]; convert to Lab.
        # rgb_to_lab is assumed to take and return (B, 3, H, W) tensors.
        pixels_ref = rgb_to_lab(pixels_ref.unsqueeze(0).clamp(0, 1)).contiguous()

        # Per-channel normalization of Lab values to roughly [0, 1]:
        # L in [0, 100], a/b in [-128, 128].
        csc_scaler = torch.tensor(
            [100, 2 * 128, 2 * 128], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)
        csc_bias = torch.tensor(
            [0, 0.5, 0.5], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)

        u = torch.nn.Parameter(
            torch.zeros_like(latents, dtype=torch.bfloat16, device=latents.device)
        )
        optimizer = torch.optim.Adam([u], lr=sw_u_lr)

        for tt in range(sw_steps):
            optimizer.zero_grad()

            x_hat_t = latents.detach() + u
            noise_pred = _no_grad_noise(
                self.transformer,
                hidden_states=x_hat_t,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
            )

            sigma_t = self.scheduler.sigmas[self.scheduler.index_for_timestep(t)]
            while sigma_t.ndim < x_hat_t.ndim:
                sigma_t = sigma_t.unsqueeze(-1)
            sigma_t = sigma_t.to(x_hat_t.dtype).to(latents.device)

            # Flow-matching prediction of the clean latent.
            x_0 = x_hat_t - sigma_t * noise_pred

            img_unscaled = self.vae.decode(
                (x_0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
                return_dict=False,
            )[0]
            image = (img_unscaled * 0.5 + 0.5).clamp(0, 1)  # (1, 3, H, W) in [0, 1]
            image_lab = rgb_to_lab(image).contiguous()

            # Flatten to (1, 3, N) pixel sequences and normalize each Lab channel.
            pred_seq = image_lab.reshape(1, 3, -1) / csc_scaler + csc_bias
            ref_seq = pixels_ref.reshape(1, 3, -1) / csc_scaler + csc_bias

            if self.subsampling_factor > 1:
                pred_seq = pred_seq[..., :: self.subsampling_factor]
                ref_seq = ref_seq[..., :: self.subsampling_factor]

            loss = self.swd(pred=pred_seq, gt=ref_seq, step=tt)

            loss.backward()
            optimizer.step()

            if write_video_animation_path is not None:
                frame_idx = cur_iter_step * sw_steps + tt
                write_img(
                    os.path.join(write_video_animation_path, f"{frame_idx:05d}.jpg"),
                    from_torch(img_unscaled.squeeze(0)),
                )

        latents = latents.detach() + u.detach()

        gc.collect()
        torch.cuda.empty_cache()
        return latents

    def __call__(
        self,
        sw_reference: Optional[PIL.Image.Image] = None,
        sw_steps: int = 8,
        sw_u_lr: float = 0.05 * 10**3,
        num_guided_steps: Optional[int] = None,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        cfg_rescale_phi: float = 0.7,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        skip_guidance_layers: Optional[List[int]] = None,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_start: float = 0.01,
        mu: Optional[float] = None,
        write_video_animation_path: Optional[str] = None,
    ):
        assert self.swd is not None, "SWD not initialized; call setup_swd() first."

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise an error if they are not valid.
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._skip_layer_guidance_scale = skip_layer_guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode the prompts.
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None)
            if self.joint_attention_kwargs is not None
            else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            if skip_guidance_layers is not None:
                original_prompt_embeds = prompt_embeds
                original_pooled_prompt_embeds = pooled_prompt_embeds
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )

        # 4. Prepare latent variables.
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps.
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            # Use separate names for the latent resolution so the pixel-space
            # height/width stay available for the reference resize below.
            _, _, latent_height, latent_width = latents.shape
            image_seq_len = (latent_height // self.transformer.config.patch_size) * (
                latent_width // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare image embeddings for the IP-Adapter, if provided.
        if (
            ip_adapter_image is not None and self.is_ip_adapter_active
        ) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {
                    "ip_adapter_image_embeds": ip_adapter_image_embeds
                }
            else:
                self._joint_attention_kwargs.update(
                    ip_adapter_image_embeds=ip_adapter_image_embeds
                )

        # 7. Prepare the SW reference: rescale it so its longer side matches the
        # longer side of the target resolution, then convert to a (3, H, W)
        # tensor in [0, 1].
        if sw_reference is not None:
            target_max_size = max(height, width)
            reference_max_size = max(sw_reference.width, sw_reference.height)
            scale_factor = target_max_size / reference_max_size

            sw_reference = sw_reference.resize(
                (
                    int(sw_reference.width * scale_factor),
                    int(sw_reference.height * scale_factor),
                )
            )
            pixels_ref = (
                torch.Tensor(np.array(sw_reference).astype(np.float32) / 255)
                .permute(2, 0, 1)
                .to(device)
                .to(torch.bfloat16)
            )

        # 8. Denoising loop.
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Sliced-Wasserstein color guidance: optimize the latents against
                # the reference before the regular denoising step.
                if sw_reference is not None:
                    if num_guided_steps is None or i < num_guided_steps:
                        latents = self.do_sw_guidance(
                            sw_steps,
                            sw_u_lr,
                            latents,
                            t,
                            prompt_embeds,
                            pooled_prompt_embeds,
                            pixels_ref,
                            cur_iter_step=i,
                            write_video_animation_path=write_video_animation_path,
                        )
                    if num_guided_steps is not None and i == num_guided_steps // 2:
                        self.swd.reset()

                # Expand the latents if classifier-free guidance is used.
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )

                with torch.no_grad():
                    timestep = t.expand(latent_model_input.shape[0])

                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        pooled_projections=pooled_prompt_embeds,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]

                # Perform classifier-free guidance.
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                    should_skip_layers = (
                        True
                        if i > num_inference_steps * skip_layer_guidance_start
                        and i < num_inference_steps * skip_layer_guidance_stop
                        else False
                    )
                    if skip_guidance_layers is not None and should_skip_layers:
                        timestep = t.expand(latents.shape[0])
                        latent_model_input = latents
                        with torch.no_grad():
                            noise_pred_skip_layers = self.transformer(
                                hidden_states=latent_model_input,
                                timestep=timestep,
                                encoder_hidden_states=original_prompt_embeds,
                                pooled_projections=original_pooled_prompt_embeds,
                                joint_attention_kwargs=self.joint_attention_kwargs,
                                return_dict=False,
                                skip_layers=skip_guidance_layers,
                            )[0]
                        noise_pred = (
                            noise_pred
                            + (noise_pred_text - noise_pred_skip_layers)
                            * self._skip_layer_guidance_scale
                        )

                    # Rescale the guided prediction towards the std of the
                    # conditional prediction to reduce over-saturation
                    # (phi = 0 keeps plain CFG, phi = 1 fully rescales).
                    if cfg_rescale_phi is not None and cfg_rescale_phi > 0:
                        sigma_pos = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
                        sigma_cfg = noise_pred.std(dim=(1, 2, 3), keepdim=True)

                        factor = torch.lerp(
                            sigma_pos / (sigma_cfg + 1e-8),
                            torch.ones_like(sigma_cfg),
                            1.0 - cfg_rescale_phi,
                        )
                        noise_pred = noise_pred * factor

                # Compute the previous noisy sample x_t -> x_{t-1}.
                latents_dtype = latents.dtype
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # Some backends (e.g. MPS) may silently upcast; cast back.
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(
                        self, i, t, callback_kwargs
                    )

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop(
                        "prompt_embeds", prompt_embeds
                    )
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds",
                        negative_pooled_prompt_embeds,
                    )

                # Keep writing animation frames after the guided steps are over.
                if (
                    write_video_animation_path is not None
                    and num_guided_steps is not None
                    and i >= num_guided_steps
                ):
                    with torch.no_grad():
                        image = self.vae.decode(
                            (latents / self.vae.config.scaling_factor)
                            + self.vae.config.shift_factor,
                            return_dict=False,
                        )[0]
                    cur_frame_idx = i * sw_steps
                    write_img(
                        os.path.join(
                            write_video_animation_path,
                            f"{cur_frame_idx:05d}.jpg",
                        ),
                        from_torch(image.squeeze(0)),
                    )

                # Update the progress bar.
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps
                    and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = (
                latents / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(
                image.detach(), output_type=output_type
            )

        # Offload all models.
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)

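# For reference, the sliced-Wasserstein loss above consumes flattened (1, 3, N)
# Lab pixel sequences. An illustrative sketch of that interface as it is used in
# ``do_sw_guidance`` (the exact signature lives in src/loss/vector_swd.py):
#
#   swd = VectorSWDLoss(num_proj=64, distance="l1", use_ucv=False, use_lcv=False,
#                       num_new_candidates=32, missing_value_method="interpolate",
#                       ess_alpha=-1, sampling_mode="qmc")
#   loss = swd(pred=pred_seq, gt=ref_seq, step=0)  # pred_seq, ref_seq: (1, 3, N)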
def run(
    prompt: str,
    reference_image: PIL.Image.Image,
    model_path: str,
    num_inference_steps: int = 30,
    num_guided_steps: int = 28,
    guidance_scale: float = 5.0,
    cfg_rescale_phi: float = 0.7,
    sw_u_lr: float = 3e-3,
    sw_steps: int = 8,
    height: int = 768,
    width: int = 768,
    device: str = "cuda",
    seed: Optional[int] = None,
    num_projections: int = 64,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    num_new_candidates: int = 32,
    subsampling_factor: int = 1,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    pipe: Optional[SWStableDiffusion3Pipeline] = None,
    compile: bool = False,
    video_animation_path: Optional[str] = None,
) -> PIL.Image.Image:
    """
    Generate an image with SW guidance from a text prompt and a reference image.

    Args:
        prompt (str): Text prompt to guide the generation.
        reference_image (PIL.Image.Image): Reference image to guide the generation.
        model_path (str): Path to the model weights.
        num_inference_steps (int): Number of denoising steps.
        num_guided_steps (int): Number of denoising steps to apply SW guidance to.
        guidance_scale (float): Scale for classifier-free guidance.
        cfg_rescale_phi (float): Rescale factor for classifier-free guidance.
        sw_u_lr (float): Learning rate for the SW guidance latent offset.
        sw_steps (int): Number of SW optimization steps per denoising step.
        height (int): Output image height.
        width (int): Output image width.
        device (str): Device to run the model on.
        seed (int, optional): Random seed for the initial latents.
        num_projections (int): Number of random projections for VectorSWDLoss.
        use_ucv (bool): Use the UCV variant of VectorSWDLoss.
        use_lcv (bool): Use the LCV variant of VectorSWDLoss.
        distance (str): Distance metric for VectorSWDLoss ("l1" or "l2").
        num_new_candidates (int): Number of new candidates for the reservoir.
        subsampling_factor (int): Factor to subsample points for SW computation.
            Higher values reduce memory usage but may affect quality.
        sampling_mode (str): Sampling mode for VectorSWDLoss.
        pipe (SWStableDiffusion3Pipeline, optional): Pipeline to use for generation.
            If None, a new pipeline is created.
        compile (bool): Whether to compile the pipeline with torch.compile.
        video_animation_path (str, optional): Directory to write per-step frames to.

    Returns:
        PIL.Image.Image: Generated image.
    """
    device = torch.device(device) if not isinstance(device, torch.device) else device
    if pipe is None:
        pipe = create_pipeline(model_path, device, compile=compile)

    pipe.setup_swd(
        num_projections=num_projections,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        distance=distance,
        num_new_candidates=num_new_candidates,
        subsampling_factor=subsampling_factor,
        sampling_mode=sampling_mode,
    )

    if seed is not None:
        print(f"Using seed: {seed}")
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = None

    image = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        num_guided_steps=num_guided_steps,
        guidance_scale=guidance_scale,
        cfg_rescale_phi=cfg_rescale_phi,
        sw_u_lr=sw_u_lr,
        sw_steps=sw_steps,
        height=height,
        width=width,
        sw_reference=reference_image,
        generator=generator,
        write_video_animation_path=video_animation_path,
    ).images[0]

    return image


def create_pipeline(
    model_path, device: Union[str, torch.device] = "cuda", compile: bool = False
):
    pipe = SWStableDiffusion3Pipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )
    pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)
    if compile:
        pipe.transformer = torch.compile(pipe.transformer)
        pipe.vae.decoder = torch.compile(pipe.vae.decoder)
    return pipe
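

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint identifier and file paths below are
    # placeholders; substitute your own SD3 weights and reference image.
    reference = PIL.Image.open("reference.jpg").convert("RGB")
    result = run(
        prompt="a watercolor painting of a lighthouse at dusk",
        reference_image=reference,
        model_path="stabilityai/stable-diffusion-3-medium-diffusers",
        seed=0,
    )
    result.save("output.png")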