import math
from typing import Optional, Union

import torch
from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_loss_weighting_for_sd3


def resolution_dependent_timestep_flow_shift(
    latents: torch.Tensor,
    sigmas: torch.Tensor,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> torch.Tensor:
    # Spatial(-temporal) sequence length: H * W for 4D image latents,
    # F * H * W for 5D video latents.
    if latents.ndim == 4:
        image_or_video_sequence_length = latents.shape[2] * latents.shape[3]
    elif latents.ndim == 5:
        image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4]
    else:
        raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor")

    # Linearly interpolate the shift between (base_image_seq_len, base_shift)
    # and (max_image_seq_len, max_shift) based on the sequence length.
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    b = base_shift - m * base_image_seq_len
    mu = m * image_or_video_sequence_length + b
    sigmas = default_flow_shift(sigmas, shift=mu)
    return sigmas
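

# Illustrative sketch (not used by the training code paths): higher-resolution
# latents produce a larger interpolated shift, which pushes sigmas toward 1 so
# that more sampled timesteps carry heavy noise. The latent shapes below are
# hypothetical.
def _example_resolution_dependent_shift() -> None:
    sigmas = torch.linspace(1.0, 1e-3, steps=8)
    latents_small = torch.randn(1, 16, 32, 32)  # sequence length 32 * 32 = 1024
    latents_large = torch.randn(1, 16, 64, 64)  # sequence length 64 * 64 = 4096
    shifted_small = resolution_dependent_timestep_flow_shift(latents_small, sigmas)
    shifted_large = resolution_dependent_timestep_flow_shift(latents_large, sigmas)
    # mu(1024) ~= 0.63 lowers the sigmas; mu(4096) = 1.15 raises them.
    assert (shifted_large >= shifted_small).all()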


def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
    # Standard flow-matching timestep shift: shift > 1 skews the schedule toward
    # noisier timesteps, shift < 1 toward cleaner ones; shift = 1 is the identity.
    sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
    return sigmas
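

# Minimal numeric sketch of the shift transform
# sigma' = shift * sigma / (1 + (shift - 1) * sigma).
def _example_default_flow_shift() -> None:
    sigmas = torch.tensor([0.25, 0.50, 0.75])
    shifted = default_flow_shift(sigmas, shift=3.0)
    # e.g. (0.5 * 3) / (1 + 2 * 0.5) = 0.75
    assert torch.allclose(shifted, torch.tensor([0.50, 0.75, 0.90]))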


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: Optional[float] = None,
    logit_std: Optional[float] = None,
    mode_scale: Optional[float] = None,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    r"""
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # Logit-normal sampling from the SD3 paper: normal samples squashed
        # through a sigmoid land in (0, 1).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u
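

# Sketch of density-based timestep sampling with the "logit_normal" scheme;
# parameter values here are illustrative, not prescribed defaults.
def _example_timestep_density_sampling() -> None:
    generator = torch.Generator().manual_seed(0)
    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal",
        batch_size=4,
        logit_mean=0.0,
        logit_std=1.0,
        generator=generator,
    )
    # All samples lie in (0, 1) and can be scaled to timestep indices.
    assert u.shape == (4,) and ((u > 0) & (u < 1)).all()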


def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> Optional[torch.Tensor]:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        # Flow-matching schedulers do not define an alphas_cumprod schedule.
        return None
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        return scheduler.alphas_cumprod.clone()
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return scheduler.sigmas.clone()
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # Rescale the discrete timesteps to a sigma-like schedule in [0, 1).
        return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps)
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
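

# Sketch comparing the two supported schedulers: the flow-matching scheduler
# exposes sigmas but no alphas_cumprod, while the DDIM scheduler provides both.
# Schedulers are constructed with their default configs purely for illustration.
def _example_scheduler_schedules() -> None:
    flow_scheduler = FlowMatchEulerDiscreteScheduler()
    assert get_scheduler_alphas(flow_scheduler) is None
    assert get_scheduler_sigmas(flow_scheduler).ndim == 1

    ddim_scheduler = CogVideoXDDIMScheduler()
    alphas = get_scheduler_alphas(ddim_scheduler)
    sigmas = get_scheduler_sigmas(ddim_scheduler)  # timesteps rescaled to [0, 1)
    assert alphas.shape == sigmas.shape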


def prepare_sigmas(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    sigmas: torch.Tensor,
    batch_size: int,
    num_train_timesteps: int,
    flow_weighting_scheme: str = "none",
    flow_logit_mean: float = 0.0,
    flow_logit_std: float = 1.0,
    flow_mode_scale: float = 1.29,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        weights = compute_density_for_timestep_sampling(
            weighting_scheme=flow_weighting_scheme,
            batch_size=batch_size,
            logit_mean=flow_logit_mean,
            logit_std=flow_logit_std,
            mode_scale=flow_mode_scale,
            device=device,
            generator=generator,
        )
        indices = (weights * num_train_timesteps).long()
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # The DDIM path samples timesteps uniformly at random.
        weights = torch.rand(size=(batch_size,), device=device, generator=generator)
        indices = (weights * num_train_timesteps).long()
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")

    return sigmas[indices]
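

# Sketch of drawing a batch of per-sample sigmas for a flow-matching training
# step; "logit_normal" concentrates samples around mid-range timesteps.
def _example_prepare_sigmas() -> None:
    scheduler = FlowMatchEulerDiscreteScheduler()
    sampled = prepare_sigmas(
        scheduler=scheduler,
        sigmas=get_scheduler_sigmas(scheduler),
        batch_size=2,
        num_train_timesteps=scheduler.config.num_train_timesteps,
        flow_weighting_scheme="logit_normal",
    )
    assert sampled.shape == (2,)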


def prepare_loss_weights(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    alphas: Optional[torch.Tensor] = None,
    sigmas: Optional[torch.Tensor] = None,
    flow_weighting_scheme: str = "none",
) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # Weight each timestep by the reciprocal of its noise variance, 1 / (1 - alpha_cumprod_t).
        return 1 / (1 - alphas)
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def prepare_target(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    noise: torch.Tensor,
    latents: torch.Tensor,
) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        # Flow matching regresses the velocity field, noise - latents.
        target = noise - latents
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # The DDIM path regresses against the clean latents.
        target = latents
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")

    return target
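

# End-to-end sketch of how these helpers compose into a flow-matching loss.
# The zero "prediction" stands in for a transformer forward pass, and the
# latent shape is a hypothetical (batch, channels, frames, height, width).
def _example_flow_matching_loss() -> None:
    scheduler = FlowMatchEulerDiscreteScheduler()
    latents = torch.randn(2, 16, 4, 32, 32)
    noise = torch.randn_like(latents)
    sigmas = prepare_sigmas(
        scheduler=scheduler,
        sigmas=get_scheduler_sigmas(scheduler),
        batch_size=latents.shape[0],
        num_train_timesteps=scheduler.config.num_train_timesteps,
    )
    sigmas_b = sigmas.view(-1, 1, 1, 1, 1)
    noisy_latents = (1.0 - sigmas_b) * latents + sigmas_b * noise  # flow-matching interpolation
    target = prepare_target(scheduler, noise=noise, latents=latents)  # velocity: noise - latents
    weights = prepare_loss_weights(scheduler, sigmas=sigmas)
    model_pred = torch.zeros_like(noisy_latents)
    loss = (weights.view(-1, 1, 1, 1, 1) * (model_pred - target) ** 2).mean()
    assert loss.ndim == 0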


def _enable_vae_memory_optimizations(vae, enable_slicing: bool = False, enable_tiling: bool = False):
    if hasattr(vae, "enable_slicing") and enable_slicing:
        vae.enable_slicing()
    if hasattr(vae, "enable_tiling") and enable_tiling:
        vae.enable_tiling()
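

# Typical call site (sketch, assuming a diffusers VAE such as AutoencoderKL):
#   vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
#   _enable_vae_memory_optimizations(vae, enable_slicing=True, enable_tiling=True)
# Slicing decodes a batch one sample at a time; tiling decodes in spatial tiles.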