Consistency Models (YSDA CV Week 2024)

This repository contains the weights of the models trained as the final assignment of YSDA CV Week 2024.

The Consistency Models were trained on top of the Stable Diffusion 1.5 (SD 1.5) checkpoint "sd-legacy/stable-diffusion-v1-5".

Training consisted of fitting additional rank-64 LoRA modules attached to some of the layers of the base model (the UNet). We considered three variants of Consistency Models:

  1. Consistency Training
  2. Consistency Distillation
  3. Multi-boundary Consistency Distillation
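In multi-boundary consistency distillation, the denoising trajectory is split into segments. The model learns to map a noisy latent at timestep t to the start (boundary) of its own segment, rather than all the way to t = 0. Plain consistency training and distillation can be viewed as the single-segment case. The card does not say how many segments were used or where the boundaries sit. The sketch below therefore assumes four equal segments over SD 1.5's 1000 training timesteps. Both the segmentation and the helper names are hypothetical illustrations, not the training code behind this repository.

```python
def segment_boundaries(num_train_timesteps: int = 1000, num_segments: int = 4) -> list[int]:
    """Lower boundary of each of `num_segments` equal segments of [0, T).

    Hypothetical choice: equal-width segments. The actual boundaries used
    for this repository's model are not stated.
    """
    step = num_train_timesteps // num_segments
    return [k * step for k in range(num_segments)]


def target_boundary(t: int, boundaries: list[int]) -> int:
    """Boundary that a multi-boundary consistency model maps timestep t to:
    the largest segment boundary that does not exceed t."""
    return max(b for b in boundaries if b <= t)


boundaries = segment_boundaries(1000, 4)
print(boundaries)  # [0, 250, 500, 750]
print([target_boundary(t, boundaries) for t in (999, 749, 499, 249)])  # [750, 500, 250, 0]
```

With a single segment (`num_segments=1`), every timestep maps to boundary 0, which recovers the usual single-boundary consistency objective.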

Each variant was trained on a 5k subset of the COCO dataset. For each model, the weights of the corresponding LoRA adapter are stored in the standard PEFT format.
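LoRA keeps a pretrained weight frozen and learns only a low-rank update to it. The sketch below is a from-scratch PyTorch illustration of what a rank-64 adapter adds to a single linear layer. It is not PEFT's implementation. The card states only the rank, so `alpha = 64` (a scaling of 1.0) is an assumption. In PEFT the corresponding knobs are `r` and `lora_alpha` in `LoraConfig`. The target modules used for training are not listed in this card.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update: W x + (alpha / r) * B(A(x))."""

    def __init__(self, base: nn.Linear, r: int = 64, alpha: float = 64.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay frozen
        # alpha=64 is an assumed value; the card only states r=64
        self.scaling = alpha / r
        # A projects down to rank r, B projects back up. B starts at zero, so the
        # wrapped layer initially reproduces the base layer exactly.
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))


# Example: a 320 -> 320 projection (320 is the channel width of the first SD 1.5 UNet block)
base = nn.Linear(320, 320)
layer = LoRALinear(base, r=64)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
x = torch.randn(2, 320)
print(trainable, frozen)  # 40960 102720
```

For this layer, the adapter trains 64 × (320 + 320) = 40,960 parameters, while the 102,720 base parameters stay frozen. This small trainable footprint is why only the adapter weights need to be shipped.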

You can reproduce the generation results for variant 3, Multi-boundary Consistency Distillation, as follows:

```python
# Jupyter/IPython magic: render figures inline (remove when running as a plain script)
%matplotlib inline
import matplotlib.pyplot as plt

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from peft import PeftModel


def visualize_images(images):
    """Show the four images generated for one prompt side by side."""
    assert len(images) == 4
    plt.figure(figsize=(12, 3))
    for i, image in enumerate(images):
        plt.subplot(1, 4, i + 1)
        plt.imshow(image)
        plt.axis('off')

    plt.subplots_adjust(wspace=-0.01, hspace=-0.01)
    plt.show()


# Load the SD 1.5 base pipeline in half precision on the GPU
pipe = StableDiffusionPipeline.from_pretrained("sd-legacy/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# Few-step DDIM sampling; "trailing" spacing makes the first step start at the last training timestep (999)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()

# Attach the Multi-boundary CD LoRA adapter (subfolder "multi-cd" of this repo) to the UNet.
# The UNet is cast to float32 for loading and back to float16 for inference.
loaded_cm_unet = PeftModel.from_pretrained(
    pipe.unet.to(torch.float32),
    "kisnikser/consistency-models",
    subfolder="multi-cd",
    adapter_name="multi-cd",
)

pipe.unet = loaded_cm_unet.eval().to(torch.float16)

validation_prompts = [
    "A sad puppy with large eyes",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
    "A girl with pale blue hair and a cami tank top",
    "A lighthouse in a giant wave, origami style",
    "belle epoque, christmas, red house in the forest, photo realistic, 8k",
    "A small cactus with a happy face in the Sahara desert",
    "Green commercial building with refrigerator and refrigeration units outside",
]

for prompt in validation_prompts:
    # Re-seed for every prompt so each row is reproducible
    generator = torch.Generator(device="cuda").manual_seed(1)
    images = pipe(
        prompt=prompt,
        guidance_scale=1.0,      # 1.0 disables classifier-free guidance
        num_inference_steps=4,   # four-step consistency sampling
        generator=generator,
        num_images_per_prompt=4,
    ).images
    visualize_images(images)
```
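The scheduler is configured with `timestep_spacing="trailing"`. As a rough illustration, the pure-Python sketch below shows which four timesteps DDIM then samples at. It is a re-implementation of the idea, not the diffusers code itself, and it assumes SD 1.5's 1000 training timesteps. The helper name `ddim_trailing_timesteps` is hypothetical. Trailing spacing counts down from the end of the schedule, so the first step starts from the noisiest training timestep.

```python
def ddim_trailing_timesteps(num_train_timesteps: int = 1000, num_inference_steps: int = 4) -> list[int]:
    """Sampling timesteps for DDIM with timestep_spacing="trailing".

    Counts down from num_train_timesteps in equal strides, then shifts by one
    so the first timestep is the last training index (num_train_timesteps - 1).
    """
    step_ratio = num_train_timesteps / num_inference_steps
    return [round(num_train_timesteps - i * step_ratio) - 1 for i in range(num_inference_steps)]


print(ddim_trailing_timesteps(1000, 4))  # [999, 749, 499, 249]
```

Changing `num_inference_steps` changes the stride accordingly. For example, two steps give `[999, 499]`.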

[Generated samples: one figure per validation prompt]

