"""Gradio demo for class-conditional ImageNet image generation with DiT-XL/2.

One DiT model and one VAE decoder are kept loaded on ``device``; they are
swapped out when the user selects a different resolution or VAE decoder.
"""
import math  # NOTE(review): unused in this file

import gradio as gr
import torch
from diffusers.models import AutoencoderKL
from PIL import Image
from torchvision.utils import make_grid  # NOTE(review): unused in this file

from diffusion import create_diffusion
from download import find_model
from imagenet_class_data import IMAGENET_1K_CLASSES
from models import DiT_XL_2

# Resolutions for which a "DiT-XL-2-{size}x{size}.pt" checkpoint is loaded.
SUPPORTED_IMAGE_SIZES = (256, 512)
# Ratio between the image side length and the latent side length.
VAE_DOWNSAMPLE_FACTOR = 8
# Channel count of the sampled latent noise.
LATENT_CHANNELS = 4
# Label used for the unconditional ("null class") half of classifier-free guidance.
NULL_CLASS_LABEL = 1000
# Sampled latents are divided by this constant before VAE decoding.
LATENT_SCALE_FACTOR = 0.18215
DEFAULT_VAE = "stabilityai/sd-vae-ft-mse"

torch.set_grad_enabled(False)  # the demo only runs inference
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(image_size=256):
    """Build a DiT-XL/2 model for ``image_size`` and load its pretrained weights.

    Args:
        image_size: Side length in pixels; must be one of SUPPORTED_IMAGE_SIZES.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        ValueError: If ``image_size`` is not supported.
    """
    if image_size not in SUPPORTED_IMAGE_SIZES:
        raise ValueError(
            f"image_size must be one of {SUPPORTED_IMAGE_SIZES}, got {image_size!r}"
        )
    latent_size = image_size // VAE_DOWNSAMPLE_FACTOR
    model = DiT_XL_2(input_size=latent_size).to(device)
    state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Currently loaded DiT model / VAE decoder and the settings they were loaded for.
model = load_model(image_size=256)
vae = AutoencoderKL.from_pretrained(DEFAULT_VAE).to(device)
current_image_size = 256
current_vae_model = DEFAULT_VAE


def _ensure_model(image_size):
    """Make the global DiT ``model`` match ``image_size``, reloading it if needed."""
    global model, current_image_size
    if model is not None and image_size == current_image_size:
        return
    # Drop the old weights before loading the new checkpoint to limit peak
    # memory. If loading raises, ``model`` stays None and the next call retries
    # instead of failing with a NameError on a deleted global.
    model = None
    model = load_model(image_size=image_size)
    current_image_size = image_size


def _ensure_vae(vae_model):
    """Make the global ``vae`` match the ``vae_model`` repo id, reloading it if needed."""
    global vae, current_vae_model
    if vae is not None and vae_model == current_vae_model:
        return
    if vae is not None and device == "cuda":
        vae.to("cpu")
    vae = None
    vae = AutoencoderKL.from_pretrained(vae_model).to(device)
    # Record what is loaded; without this, switching back to the previous
    # decoder was skipped and the wrong VAE kept being used.
    current_vae_model = vae_model


def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, n, seed):
    """Sample ``n`` images of one ImageNet class with classifier-free guidance.

    Args:
        image_size: Resolution string such as ``"256x256"``; only the part
            before ``"x"`` is used.
        vae_model: Pretrained ``AutoencoderKL`` repo id used to decode latents.
        class_label: Integer class index (the UI dropdown uses ``type="index"``).
        cfg_scale: Classifier-free guidance scale passed to the model.
        num_sampling_steps: Number of diffusion sampling steps.
        n: Number of images to generate.
        seed: Seed for ``torch.manual_seed``.

    Returns:
        A list of ``n`` ``PIL.Image`` objects.
    """
    image_size = int(image_size.split("x")[0])
    # UI sliders/number boxes may deliver floats; tensor sizes, list repetition
    # and the timestep-respacing string below all need plain ints.
    n = int(n)
    num_sampling_steps = int(num_sampling_steps)

    _ensure_model(image_size)
    _ensure_vae(vae_model)

    torch.manual_seed(int(seed))
    diffusion = create_diffusion(str(num_sampling_steps))

    # Sampling noise in latent space, plus the requested class for every sample.
    latent_size = image_size // VAE_DOWNSAMPLE_FACTOR
    z = torch.randn(n, LATENT_CHANNELS, latent_size, latent_size, device=device)
    y = torch.tensor([class_label] * n, device=device)

    # Classifier-free guidance: run the same noise with the real labels and
    # with the null-class label in one batch.
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([NULL_CLASS_LABEL] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=cfg_scale)

    samples = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    )
    samples, _ = samples.chunk(2, dim=0)  # drop the null-class half
    samples = vae.decode(samples / LATENT_SCALE_FACTOR).sample

    # Map decoder output (expected in roughly [-1, 1]) to uint8 HWC arrays.
    # Adding 128.0 instead of 127.5 makes the truncating uint8 cast round to
    # the nearest integer.
    samples = (
        samples.mul(127.5)
        .add_(128.0)
        .clamp_(0, 255)
        .permute(0, 2, 3, 1)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return [Image.fromarray(sample) for sample in samples]


# NOTE(review): the link/badge/heading markup in the UI strings below appears to
# have been stripped from this copy of the file; only the remaining text is kept.
description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.'''

duplicate = '''Skip the queue by duplicating this space and upgrading to GPU in settings Duplicate Space'''

project_links = "\n\nProject Page · Colab · Paper · GitHub\n\n"

examples = [
    ["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 4, 1000],
    ["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 4, 1],
    ["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 4, 1],
    ["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 4, 7],
    ["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 4, 0],
    ["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 4.0, 200, 4, 1],
    ["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 4, 3],
    ["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 4, 2],
]

with gr.Blocks() as demo:
    # A plain "..." literal cannot span lines; escapes keep the same value.
    gr.Markdown("\n\nScalable Diffusion Models with Transformers (DiT)\n\n")
    gr.Markdown(project_links)
    gr.Markdown(description)
    gr.Markdown(duplicate)
    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image_size = gr.Radio(
                            choices=["256x256", "512x512"],
                            value="256x256",
                            label='DiT Model Resolution',
                        )
                        vae_model = gr.Radio(
                            choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"],
                            value="stabilityai/sd-vae-ft-mse",
                            label='VAE Decoder',
                        )
                    with gr.Row():
                        i1k_class = gr.Dropdown(
                            choices=list(IMAGENET_1K_CLASSES.values()),
                            value='golden retriever',
                            type="index",
                            label='ImageNet-1K Class',
                        )
                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0,
                                          label='Classifier-free Guidance Scale')
                    steps = gr.Slider(minimum=4, maximum=1000, step=1, value=75,
                                      label='Sampling Steps')
                    n = gr.Slider(minimum=1, maximum=16, step=1, value=1,
                                  label='Number of Samples')
                    seed = gr.Number(value=0, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto")
            # Same argument order as generate()'s signature.
            generate_inputs = [image_size, vae_model, i1k_class, cfg_scale, steps, n, seed]
            button.click(generate, inputs=generate_inputs, outputs=[output])
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    fn=generate,
                    inputs=generate_inputs,
                    outputs=[output],
                    cache_examples=True,
                )

demo.queue()
demo.launch()