# Author: Julian Bilcke
# Vendored copy of finetrainers training utilities ("we are going to hack into
# finetrainers"), taken at upstream commit 9fd1204 (file size ~6.17 kB).
import math
from typing import Optional, Union
import torch
from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_loss_weighting_for_sd3
# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
def resolution_dependent_timestep_flow_shift(
    latents: torch.Tensor,
    sigmas: torch.Tensor,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> torch.Tensor:
    """Shift flow-matching sigmas with a shift factor that depends on latent resolution.

    The shift ``mu`` is linearly interpolated between ``base_shift`` (at
    ``base_image_seq_len`` latent tokens) and ``max_shift`` (at
    ``max_image_seq_len`` latent tokens), then applied via ``default_flow_shift``.

    Args:
        latents: 4D image latents (B, C, H, W) or 5D video latents (B, C, F, H, W);
            only used to derive the spatial(-temporal) sequence length.
        sigmas: Sigmas to be shifted.
        base_image_seq_len / max_image_seq_len: Sequence-length anchors for the
            interpolation (defaults copied from diffusers, see link above).
        base_shift / max_shift: Shift values at the two anchors.

    Returns:
        The shifted sigmas.

    Raises:
        ValueError: If ``latents`` is neither 4D nor 5D.
    """
    if latents.ndim == 4:
        image_or_video_sequence_length = latents.shape[2] * latents.shape[3]
    elif latents.ndim == 5:
        image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4]
    else:
        raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor")

    # Linear interpolation of the shift as a function of sequence length.
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    b = base_shift - m * base_image_seq_len
    mu = m * image_or_video_sequence_length + b

    # Bug fix: `default_flow_shift` takes (sigmas, shift) — the previous call
    # passed `latents` as an extra first argument, which raised a TypeError.
    sigmas = default_flow_shift(sigmas, shift=mu)
    return sigmas
def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
    """Apply the standard flow-matching time shift ``s' = shift * s / (1 + (shift - 1) * s)``.

    A ``shift`` of 1.0 is the identity; larger values push sigmas toward 1.
    """
    scaled = sigmas * shift
    return scaled / (1 + (shift - 1) * sigmas)
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    r"""
    Compute the density for sampling the timesteps when doing SD3 training.
    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    shape = (batch_size,)
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$): sigmoid of a Gaussian draw.
        gaussian = torch.normal(mean=logit_mean, std=logit_std, size=shape, device=device, generator=generator)
        return torch.nn.functional.sigmoid(gaussian)

    # Remaining schemes start from a uniform draw in [0, 1).
    uniform = torch.rand(size=shape, device=device, generator=generator)
    if weighting_scheme == "mode":
        return 1 - uniform - mode_scale * (torch.cos(math.pi * uniform / 2) ** 2 - 1 + uniform)
    return uniform
def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
    """Return a copy of the scheduler's cumulative-alpha schedule, or ``None``.

    Flow-matching schedulers have no alpha schedule, so ``None`` is returned for them.

    Raises:
        ValueError: For any scheduler type other than the two supported ones.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        # No alphas in the flow-matching formulation.
        return None
    if isinstance(scheduler, CogVideoXDDIMScheduler):
        return scheduler.alphas_cumprod.clone()
    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
    """Return a per-timestep sigma schedule for the given scheduler.

    For flow-matching schedulers the native sigmas are copied; for the CogVideoX
    DDIM scheduler the timesteps are normalized by ``num_train_timesteps`` so the
    result lies in [0, 1), playing the role of sigmas downstream.

    Raises:
        ValueError: For any scheduler type other than the two supported ones.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return scheduler.sigmas.clone()
    if isinstance(scheduler, CogVideoXDDIMScheduler):
        normalizer = float(scheduler.config.num_train_timesteps)
        return scheduler.timesteps.clone().float() / normalizer
    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
def prepare_sigmas(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    sigmas: torch.Tensor,
    batch_size: int,
    num_train_timesteps: int,
    flow_weighting_scheme: str = "none",
    flow_logit_mean: float = 0.0,
    flow_logit_std: float = 1.0,
    flow_mode_scale: float = 1.29,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sample one sigma per batch element from the precomputed ``sigmas`` table.

    For flow-matching schedulers, timestep densities follow the configured SD3
    weighting scheme; for the CogVideoX DDIM scheduler, sampling is uniform.

    Raises:
        ValueError: For any scheduler type other than the two supported ones.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        weights = compute_density_for_timestep_sampling(
            weighting_scheme=flow_weighting_scheme,
            batch_size=batch_size,
            logit_mean=flow_logit_mean,
            logit_std=flow_logit_std,
            mode_scale=flow_mode_scale,
            device=device,
            generator=generator,
        )
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes.
        weights = torch.rand(size=(batch_size,), device=device, generator=generator)
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
    # Map densities in [0, 1) onto discrete indices into the sigma table.
    timestep_indices = (weights * num_train_timesteps).long()
    return sigmas[timestep_indices]
def prepare_loss_weights(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    alphas: Optional[torch.Tensor] = None,
    sigmas: Optional[torch.Tensor] = None,
    flow_weighting_scheme: str = "none",
) -> torch.Tensor:
    """Compute per-sample loss weights appropriate to the scheduler type.

    Flow-matching weights are delegated to diffusers' SD3 helper; the CogVideoX
    DDIM path uses the SNR-like term that CogVideoX itself trains with.

    Raises:
        ValueError: For any scheduler type other than the two supported ones.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
    if isinstance(scheduler, CogVideoXDDIMScheduler):
        # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas).
        # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results.
        return 1 / (1 - alphas)
    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
def prepare_target(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    noise: torch.Tensor,
    latents: torch.Tensor,
) -> torch.Tensor:
    """Build the regression target for the denoiser given the scheduler type.

    Flow matching regresses the velocity ``noise - latents``; the CogVideoX DDIM
    path uses the clean latents themselves as the target.

    Raises:
        ValueError: For any scheduler type other than the two supported ones.
    """
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return noise - latents
    if isinstance(scheduler, CogVideoXDDIMScheduler):
        return latents
    raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
def _enable_vae_memory_optimizations(vae, enable_slicing: bool = False, enable_tiling: bool = False):
if hasattr(vae, "enable_slicing") and enable_slicing:
vae.enable_slicing()
if hasattr(vae, "enable_tiling") and enable_tiling:
vae.enable_tiling()