# Implementation of the SW Guidance method with our enhanced SWD implementation
# See: https://github.com/alobashev/sw-guidance/ for the original implementation
#
# Alexander Lobashev, Maria Larchenko, Dmitry Guskov
# Color Conditional Generation with Sliced Wasserstein Guidance
# https://arxiv.org/abs/2503.19034
import gc
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import numpy as np
import PIL
import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    StableDiffusion3Pipeline,
)
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
    XLA_AVAILABLE,
    StableDiffusion3PipelineOutput,
    calculate_shift,
    retrieve_timesteps,
)

from src.loss.vector_swd import VectorSWDLoss
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, write_img

if XLA_AVAILABLE:
    from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import xm


def _no_grad_noise(model, *args, **kw):
    """Forward pass with grad disabled; result is returned detached."""
    with torch.no_grad():
        return model(*args, **kw, return_dict=False)[0].detach()
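
# SW Guidance in brief (a summary of do_sw_guidance below; see the paper at
# https://arxiv.org/abs/2503.19034 for the full derivation): at each guided
# denoising step an additive latent correction u is optimized with Adam for
# `sw_steps` iterations so that the colors of the predicted clean image match
# the reference palette:
#
#     x_hat_t = x_t + u
#     v       = transformer(x_hat_t, t)       # detached, treated as a constant
#     x_0     = x_hat_t - sigma_t * v         # flow-matching x_0 estimate
#     loss    = SWD(Lab(decode(x_0)), Lab(reference))
#
# Gradients reach u only through the VAE decoder and the Lab conversion,
# never through the transformer.
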
# ---------------- explicit pipeline forward call
class SWStableDiffusion3Pipeline(StableDiffusion3Pipeline):
    swd: VectorSWDLoss = None

    def setup_swd(
        self,
        num_projections: int = 64,
        use_ucv: bool = False,
        use_lcv: bool = False,
        distance: Literal["l1", "l2"] = "l1",
        num_new_candidates: int = 32,
        subsampling_factor: int = 1,
        sampling_mode: Literal["gaussian", "qmc"] = "qmc",
    ):
        self.swd = VectorSWDLoss(
            num_proj=num_projections,
            distance=distance,
            use_ucv=use_ucv,
            use_lcv=use_lcv,
            num_new_candidates=num_new_candidates,
            missing_value_method="interpolate",
            ess_alpha=-1,
            sampling_mode=sampling_mode,
        ).to(self.device)
        self.subsampling_factor = subsampling_factor

    def do_sw_guidance(
        self,
        sw_steps,
        sw_u_lr,
        latents,
        t,
        prompt_embeds,
        pooled_prompt_embeds,
        pixels_ref,
        cur_iter_step,
        write_video_animation_path: Optional[str] = None,
    ):
        if sw_steps == 0:
            return latents

        if latents.shape[0] != prompt_embeds.shape[0]:
            prompt_embeds = prompt_embeds[1].unsqueeze(0)
            pooled_prompt_embeds = pooled_prompt_embeds[1].unsqueeze(0)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0])

        pixels_ref = (
            rgb_to_lab(pixels_ref.unsqueeze(0).clamp(0, 1).permute(0, 3, 1, 2))
            .permute(0, 2, 3, 1)
            .contiguous()
        )

        csc_scaler = torch.tensor(
            [100, 2 * 128, 2 * 128], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)
        csc_bias = torch.tensor(
            [0, 0.5, 0.5], dtype=torch.bfloat16, device=latents.device
        ).view(1, 3, 1)

        u = torch.nn.Parameter(
            torch.zeros_like(latents, dtype=torch.bfloat16, device=latents.device)
        )
        optimizer = torch.optim.Adam([u], lr=sw_u_lr)

        for tt in range(sw_steps):
            optimizer.zero_grad()
            x_hat_t = latents.detach() + u
            noise_pred = _no_grad_noise(
                self.transformer,
                hidden_states=x_hat_t,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
            )

            # ------------ Compute x_0
            sigma_t = self.scheduler.sigmas[
                self.scheduler.index_for_timestep(t)
            ]  # scalar
            while sigma_t.ndim < x_hat_t.ndim:
                sigma_t = sigma_t.unsqueeze(-1)
            sigma_t = sigma_t.to(x_hat_t.dtype).to(latents.device)
            x_0 = x_hat_t - sigma_t * noise_pred

            # ------------ Compute loss
            img_unscaled = self.vae.decode(
                (x_0 / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
                return_dict=False,
            )[0]
            image = (img_unscaled * 0.5 + 0.5).clamp(0, 1)
            image_matched = (
                rgb_to_lab(image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
            )

            # reshape to (B, D, N) where D=3, N = H*W
            pred_seq = image_matched.view(1, 3, -1) / csc_scaler + csc_bias
            ref_seq = pixels_ref.view(1, 3, -1) / csc_scaler + csc_bias

            # Apply subsampling if enabled
            if self.subsampling_factor > 1:
                pred_seq = pred_seq[..., :: self.subsampling_factor]
                ref_seq = ref_seq[..., :: self.subsampling_factor]

            loss = self.swd(pred=pred_seq, gt=ref_seq, step=tt)
            loss.backward()
            optimizer.step()

            if write_video_animation_path is not None:
                frame_idx = cur_iter_step * sw_steps + tt
                write_img(
                    os.path.join(write_video_animation_path, f"{frame_idx:05d}.jpg"),
                    from_torch(img_unscaled.squeeze(0)),
                )

        latents = latents.detach() + u.detach()

        gc.collect()
        torch.cuda.empty_cache()

        return latents
    def __call__(
        self,
        sw_reference: Optional[PIL.Image.Image] = None,
        sw_steps: int = 8,
        sw_u_lr: float = 0.05 * 10**3,
        num_guided_steps: Optional[int] = None,
        # -----------------------------------
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        cfg_rescale_phi: float = 0.7,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        skip_guidance_layers: List[int] = None,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_start: float = 0.01,
        mu: Optional[float] = None,
        write_video_animation_path: Optional[str] = None,
    ):
        assert self.swd is not None, "SWD not initialized"

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._skip_layer_guidance_scale = skip_layer_guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None)
            if self.joint_attention_kwargs is not None
            else None
        )

        # 3. Encode prompts
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            if skip_guidance_layers is not None:
                original_prompt_embeds = prompt_embeds
                original_pooled_prompt_embeds = pooled_prompt_embeds
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            _, _, height, width = latents.shape
            image_seq_len = (height // self.transformer.config.patch_size) * (
                width // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare image embeddings
        if (
            ip_adapter_image is not None and self.is_ip_adapter_active
        ) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {
                    "ip_adapter_image_embeds": ip_adapter_image_embeds
                }
            else:
                self._joint_attention_kwargs.update(
                    ip_adapter_image_embeds=ip_adapter_image_embeds
                )

        if sw_reference is not None:
            # Resize the reference so its longer side matches the longer side
            # of the output image
            target_max_size = max(height, width)
            reference_max_size = max(sw_reference.width, sw_reference.height)
            scale_factor = target_max_size / reference_max_size
            sw_reference = sw_reference.resize(
                (
                    int(sw_reference.width * scale_factor),
                    int(sw_reference.height * scale_factor),
                )
            )
            pixels_ref = (
                torch.Tensor(np.array(sw_reference).astype(np.float32) / 255)
                .permute(2, 0, 1)
                .to(device)
                .to(torch.bfloat16)
            )
        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible
                # with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

                # SW Guidance
                if sw_reference is not None:
                    if num_guided_steps is None or i < num_guided_steps:
                        latents = self.do_sw_guidance(
                            sw_steps,
                            sw_u_lr,
                            latents,
                            t,
                            prompt_embeds,
                            pooled_prompt_embeds,
                            pixels_ref,
                            cur_iter_step=i,
                            write_video_animation_path=write_video_animation_path,
                        )
                        # guard against num_guided_steps=None before the
                        # integer division
                        if (
                            num_guided_steps is not None
                            and i == num_guided_steps // 2
                        ):
                            self.swd.reset()

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )

                with torch.no_grad():
                    timestep = t.expand(latent_model_input.shape[0])
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        pooled_projections=pooled_prompt_embeds,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                    should_skip_layers = (
                        i > num_inference_steps * skip_layer_guidance_start
                        and i < num_inference_steps * skip_layer_guidance_stop
                    )
                    if skip_guidance_layers is not None and should_skip_layers:
                        timestep = t.expand(latents.shape[0])
                        latent_model_input = latents
                        noise_pred_skip_layers = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=original_prompt_embeds,
                            pooled_projections=original_pooled_prompt_embeds,
                            joint_attention_kwargs=self.joint_attention_kwargs,
                            return_dict=False,
                            skip_layers=skip_guidance_layers,
                        )[0]
                        noise_pred = (
                            noise_pred
                            + (noise_pred_text - noise_pred_skip_layers)
                            * self._skip_layer_guidance_scale
                        )

                    # Based on Sec. 3.4 of Lin, Liu, Li, Yang -
                    # Common Diffusion Noise Schedules and Sample Steps are Flawed
                    # https://arxiv.org/abs/2305.08891
                    # While flow matching is free of most of these issues, a high
                    # CFG scale can still cause over-exposure, as discussed in
                    # that work.
                    if cfg_rescale_phi is not None and cfg_rescale_phi > 0:
                        # sigma_pos and sigma_cfg are per-sample (B, 1, 1, 1) stdevs
                        sigma_pos = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
                        sigma_cfg = noise_pred.std(dim=(1, 2, 3), keepdim=True)
                        # Linear blend between the raw ratio and 1,
                        # cf. Eq. (15)-(16) in the paper
                        factor = torch.lerp(
                            sigma_pos / (sigma_cfg + 1e-8),  # avoid div-by-zero
                            torch.ones_like(sigma_cfg),
                            1.0 - cfg_rescale_phi,
                        )
                        noise_pred = noise_pred * factor

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. apple mps) misbehave due to a
                        # pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(
                        self, i, t, callback_kwargs
                    )

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop(
                        "prompt_embeds", prompt_embeds
                    )
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds",
                        negative_pooled_prompt_embeds,
                    )

                if (
                    write_video_animation_path is not None
                    and num_guided_steps is not None
                    and i >= num_guided_steps
                ):
                    with torch.no_grad():
                        image = self.vae.decode(
                            (latents / self.vae.config.scaling_factor)
                            + self.vae.config.shift_factor,
                            return_dict=False,
                        )[0]
                        cur_frame_idx = i * sw_steps
                        write_img(
                            os.path.join(
                                write_video_animation_path,
                                f"{cur_frame_idx:05d}.jpg",
                            ),
                            from_torch(image.squeeze(0)),
                        )

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = (
                latents / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(
                image.detach(), output_type=output_type
            )

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)
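
# The guidance loss used above comes from src/loss/vector_swd.py, which is not
# part of this file. For orientation only, the following is a minimal sketch of
# a plain sliced Wasserstein distance between two equally sized (B, D, N) point
# clouds (here D=3 Lab channels, N pixels). It is NOT the repository's
# VectorSWDLoss, which additionally implements the UCV/LCV control variates,
# a candidate reservoir, QMC projection sampling, and interpolation for
# point clouds of unequal size.
def _sliced_wasserstein_sketch(
    pred: torch.Tensor, gt: torch.Tensor, num_proj: int = 64
) -> torch.Tensor:
    """Average 1-D (L1) Wasserstein distance over random projection directions."""
    _, d, _ = pred.shape
    # Random directions on the unit sphere in R^D, shared across the batch.
    proj = torch.randn(d, num_proj, device=pred.device, dtype=pred.dtype)
    proj = proj / proj.norm(dim=0, keepdim=True)
    # Project both point clouds onto every direction: (B, num_proj, N).
    pred_proj = torch.einsum("bdn,dk->bkn", pred, proj)
    gt_proj = torch.einsum("bdn,dk->bkn", gt, proj)
    # In 1-D the optimal transport coupling is the monotone (sorted) matching.
    pred_sorted, _ = torch.sort(pred_proj, dim=-1)
    gt_sorted, _ = torch.sort(gt_proj, dim=-1)
    return (pred_sorted - gt_sorted).abs().mean()
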
def run(
    prompt: str,
    reference_image: PIL.Image.Image,
    model_path: str,
    num_inference_steps: int = 30,
    num_guided_steps: int = 28,
    guidance_scale: float = 5.0,
    cfg_rescale_phi: float = 0.7,
    sw_u_lr: float = 3e-3,
    sw_steps: int = 8,
    height: int = 768,
    width: int = 768,
    device: str = "cuda",
    seed: Optional[int] = None,
    # SW-related parameters
    num_projections: int = 64,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    num_new_candidates: int = 32,
    subsampling_factor: int = 1,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    pipe: Optional[SWStableDiffusion3Pipeline] = None,
    compile: bool = False,
    video_animation_path: Optional[str] = None,
) -> PIL.Image.Image:
    """
    Generate an image using SW Guidance with a given prompt and reference image.

    Args:
        prompt (str): Text prompt to guide the generation
        reference_image (PIL.Image.Image): Reference image to guide the generation
        model_path (str): Path to the model weights
        num_inference_steps (int): Number of denoising steps
        num_guided_steps (int): Number of steps to apply SW guidance
        guidance_scale (float): Scale for classifier-free guidance
        cfg_rescale_phi (float): Rescale factor for classifier-free guidance
        sw_u_lr (float): Learning rate for SW guidance
        sw_steps (int): Number of SW optimization steps per guided denoising step
        height (int): Output image height
        width (int): Output image width
        device (str): Device to run the model on
        seed (int, optional): Random seed for reproducible generation
        num_projections (int): Number of random projections for VectorSWDLoss
        use_ucv (bool): Use UCV variant of VectorSWDLoss
        use_lcv (bool): Use LCV variant of VectorSWDLoss
        distance (str): Distance metric for VectorSWDLoss ("l1" or "l2")
        num_new_candidates (int): Number of new candidates for the reservoir
        subsampling_factor (int): Factor to subsample points for SW computation.
            Higher values reduce memory usage but may affect quality.
        sampling_mode (str): Sampling mode for VectorSWDLoss.
        pipe (SWStableDiffusion3Pipeline): Pipeline to use for generation.
            If None, a new pipeline is created.
        compile (bool): Whether to compile the pipeline.
        video_animation_path (str, optional): Directory to write per-step frames
            of the guidance process for a video animation.

    Returns:
        PIL.Image.Image: Generated image
    """
    # Normalize device to torch.device for robustness
    device = torch.device(device) if not isinstance(device, torch.device) else device

    if pipe is None:
        pipe = create_pipeline(model_path, device, compile=compile)

    pipe.setup_swd(
        num_projections=num_projections,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        distance=distance,
        num_new_candidates=num_new_candidates,
        subsampling_factor=subsampling_factor,
        sampling_mode=sampling_mode,
    )

    if seed is not None:
        print(f"Using seed: {seed}")
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = None

    image = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        num_guided_steps=num_guided_steps,
        guidance_scale=guidance_scale,
        cfg_rescale_phi=cfg_rescale_phi,
        sw_u_lr=sw_u_lr,
        sw_steps=sw_steps,
        height=height,
        width=width,
        sw_reference=reference_image,
        generator=generator,
        write_video_animation_path=video_animation_path,
    ).images[0]

    return image


def create_pipeline(model_path, device: str = "cuda", compile: bool = False):
    pipe = SWStableDiffusion3Pipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )
    pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)
    if compile:
        pipe.transformer = torch.compile(pipe.transformer)
        pipe.vae.decoder = torch.compile(pipe.vae.decoder)
    return pipe
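
# Example usage -- a minimal sketch. The model identifier, prompt, and file
# paths below are placeholders, not something this repository ships.
if __name__ == "__main__":
    from PIL import Image

    reference = Image.open("reference.jpg").convert("RGB")
    result = run(
        prompt="a mountain landscape at sunset",
        reference_image=reference,
        model_path="stabilityai/stable-diffusion-3-medium-diffusers",
        num_inference_steps=30,
        num_guided_steps=28,
        seed=42,
    )
    result.save("sw_guided_output.jpg")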