
Generic Diffusion Framework (GDF)

Basic usage

GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM/DDIM, EDM, rectified flows, etc.) and makes it easy to switch between them or to combine parts of different frameworks.

Using GDF is very straightforward. First, define an instance of the GDF class:

from gdf import GDF
from gdf import CosineSchedule
from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight

gdf = GDF(
    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
    input_scaler=VPScaler(), target=EpsilonTarget(),
    noise_cond=CosineTNoiseCond(),
    loss_weight=P2LossWeight(),
)

You need to define the following components:

  • Train Schedule: returns the logSNR schedule used during training; some schedules can be configured. A train schedule is called with a batch size and randomly samples logSNR values from the distribution it defines (see the sketch after this list).
  • Sample Schedule: the schedule that will be used later on when sampling. It may differ from the training schedule.
  • Input Scaler: defines how the input and the noise are mixed, e.g. Variance Preserving (VP) or LERP (rectified flows).
  • Target: what the model is trained to predict, usually epsilon, x0 or v.
  • Noise Conditioning: you could pass the logSNR directly to your model, but usually a normalized value is used instead; for example, the EDM framework proposes -logSNR/8.
  • Loss Weight: there are many proposed loss weighting strategies; here you define which one you'll use.
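
For intuition, a train schedule is just a callable that maps a batch size to a batch of logSNR values sampled from some distribution. A minimal custom one could look like the sketch below (illustrative only, not a class shipped with gdf; the shift handling follows the Simple Diffusion idea of shifting logSNR by -2*log(shift), which is an assumption here):

import math
import torch

class UniformLogSNRSchedule:
    # Toy train schedule: sample logSNR uniformly from a fixed range
    def __init__(self, logsnr_range=(-10.0, 10.0)):
        self.logsnr_range = logsnr_range

    def __call__(self, batch_size, shift=1):
        lo, hi = self.logsnr_range
        logSNR = torch.empty(batch_size).uniform_(lo, hi)
        if shift != 1:
            logSNR = logSNR - 2 * math.log(shift)  # assumed shift convention
        return logSNR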

All of these components are very simple, logSNR-centric definitions. For example, the VPScaler is defined as just:

class VPScaler():
    def __call__(self, logSNR):
        # Variance preserving: a^2 + b^2 = 1, with a^2 = sigmoid(logSNR)
        # so that logSNR = log(a^2 / b^2)
        a_squared = logSNR.sigmoid()
        a = a_squared.sqrt()
        b = (1 - a_squared).sqrt()
        return a, b

So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc.
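
As an illustration of how small these pieces are, here are hedged sketches of three custom components: an EDM-style noise conditioning implementing the -logSNR/8 normalization mentioned above, a LERP scaler for rectified flows, and a v-prediction target. The gdf package ships its own versions of these ideas; the class names and exact __call__ signatures below are assumptions:

class EDMNoiseCondSketch:
    # EDM-style conditioning: feed the model a normalized value
    # instead of the raw logSNR
    def __call__(self, logSNR):
        return -logSNR / 8

class LERPScalerSketch:
    # Rectified flows / LERP: x_t = (1 - t) * x0 + t * noise.
    # With a = 1 - t and b = t, logSNR = log(a^2 / b^2) = 2 * log((1 - t) / t),
    # so t = sigmoid(-logSNR / 2)
    def __call__(self, logSNR):
        t = (-logSNR / 2).sigmoid()
        return 1 - t, t

class VTargetSketch:
    # v-prediction target: v = a * epsilon - b * x0
    # (the argument list is an assumption about what GDF passes to its target)
    def __call__(self, x0, epsilon, logSNR, a, b):
        return a * epsilon - b * x0

Any of these can be passed to the GDF constructor in place of the defaults used above.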

Training

When you define your training loop, you can get everything you need by just calling gdf.diffuse:

import torch.nn as nn

shift, loss_shift = 1, 1  # can be set higher, as the Simple Diffusion paper suggests for high resolutions
for inputs, extra_conditions in dataloader_iterator:
    # diffuse() samples logSNR values from the train schedule, noises the inputs,
    # and returns everything needed to compute the loss
    noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
    pred = diffusion_model(noised, noise_cond, extra_conditions)

    # per-sample MSE against the configured target, weighted by the configured loss weight
    loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
    loss_adjusted = (loss * loss_weight).mean()

    loss_adjusted.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

And that's all: you have a diffusion model training, and every element of the training is easy to customize through the GDF class.
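
For intuition, here is a conceptual sketch of what gdf.diffuse does with the components defined above. This is not the library's exact code; the attribute names gdf.schedule, gdf.input_scaler, gdf.noise_cond and gdf.loss_weight mirror the constructor arguments but are assumptions:

import torch

def diffuse_sketch(gdf, x0, shift=1, loss_shift=1):
    # sample a batch of logSNR values from the train schedule
    logSNR = gdf.schedule(x0.size(0), shift=shift).to(x0.device)
    a, b = gdf.input_scaler(logSNR)  # e.g. VPScaler above
    # broadcast the per-sample scalars over the non-batch dimensions
    while a.dim() < x0.dim():
        a, b = a.unsqueeze(-1), b.unsqueeze(-1)
    noise = torch.randn_like(x0)
    noised = a * x0 + b * noise  # forward diffusion
    target = noise               # assuming EpsilonTarget
    return noised, noise, target, logSNR, gdf.noise_cond(logSNR), gdf.loss_weight(logSNR, shift=loss_shift)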

Sampling

The other important part is sampling. When you want to sample with this framework, you can just do the following:

from gdf import DDPMSampler

shift = 1
sampling_configs = {
    "timesteps": 30,      # number of sampling steps
    "cfg": 7,             # classifier-free guidance scale
    "sampler": DDPMSampler(gdf),
    "shift": shift,
    "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999]),  # sample schedule
}

# gdf.sample is a generator that yields intermediate results at every step;
# the starred unpacking keeps only the final output
*_, (sampled, _, _) = gdf.sample(
    diffusion_model, {"cond": extra_conditions}, latents.shape,
    unconditional_inputs={"cond": torch.zeros_like(extra_conditions)},
    device=device, **sampling_configs
)
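
Because gdf.sample yields a tuple at every timestep, you can also iterate over the generator to inspect intermediate denoising states, for example to save previews. A hedged sketch follows: treating the first tuple element as the current sample is inferred from the unpacking pattern above, and save_preview is a hypothetical helper:

for step, (current, *_rest) in enumerate(gdf.sample(
    diffusion_model, {"cond": extra_conditions}, latents.shape,
    unconditional_inputs={"cond": torch.zeros_like(extra_conditions)},
    device=device, **sampling_configs
)):
    if step % 10 == 0:
        save_preview(current, step)  # hypothetical helper, not part of gdf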

Available modules

TODO