import torch
from torchvision.utils import make_grid
import math
from PIL import Image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
import gradio as gr
from imagenet_class_data import IMAGENET_1K_CLASSES
from download import find_model
from models import DiT_XL_2


def load_model(image_size=256):
    assert image_size in [256, 512]
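    # DiT-XL/2 operates on VAE latents, which are 8x smaller than the image, so a
    # 256x256 (or 512x512) image corresponds to a 32x32 (or 64x64) latent grid.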
    latent_size = image_size // 8
    model = DiT_XL_2(input_size=latent_size).to(device)
    state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    model.load_state_dict(state_dict)
    model.eval()
    return model
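

# Inference-only script: gradients are never needed, so disable autograd globally.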
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model(image_size=256)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
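
# Remember which DiT checkpoint and VAE decoder are currently loaded so that generate()
# only reloads them when the user picks a different resolution or VAE.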
current_image_size = 256
current_vae_model = "stabilityai/sd-vae-ft-mse"


def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, n, seed):
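    # The resolution radio button passes a string such as "256x256"; extract the integer size.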
    image_size = int(image_size.split("x")[0])
    global current_image_size
    if image_size != current_image_size:
        global model
        del model
        # if device == "cuda":
        #     torch.cuda.empty_cache()
        model = load_model(image_size=image_size)
        current_image_size = image_size

    global current_vae_model
    if vae_model != current_vae_model:
        global vae
        if device == "cuda":
            vae.to("cpu")
        del vae
        vae = AutoencoderKL.from_pretrained(vae_model).to(device)
        current_vae_model = vae_model

    # Seed PyTorch:
    torch.manual_seed(seed)

    # Setup diffusion
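    # create_diffusion takes the timestep-respacing spec as a string; passing the step count
    # samples with that many evenly spaced steps instead of the full 1000-step schedule.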
    diffusion = create_diffusion(str(num_sampling_steps))

    # Create sampling noise:
    latent_size = image_size // 8
    z = torch.randn(n, 4, latent_size, latent_size, device=device)
    y = torch.tensor([class_label] * n, device=device)

    # Setup classifier-free guidance:
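    # Double the batch: the first half is conditioned on y, the second half on the null
    # class (index 1000). forward_with_cfg combines the two noise predictions as
    # uncond + cfg_scale * (cond - uncond).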
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([1000] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=cfg_scale)

    # Sample images:
    samples = diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
    )
    samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
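    # The SD VAE scales latents by 0.18215 during encoding, so undo that factor before decoding.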
    samples = vae.decode(samples / 0.18215).sample

    # Convert to PIL.Image format:
    samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
    samples = [Image.fromarray(sample) for sample in samples]
    return samples


description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with
transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.'''
duplicate = '''Skip the queue by duplicating this space and upgrading to a GPU in the settings
<a href="https://huggingface.co/spaces/wpeebles/DiT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>'''
project_links = '''
<p style="text-align: center">
<a href="https://www.wpeebles.com/DiT.html">Project Page</a> ·
<a href="http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb">Colab</a> ·
<a href="http://arxiv.org/abs/2212.09748">Paper</a> ·
<a href="https://github.com/facebookresearch/DiT">GitHub</a></p>'''
examples = [
["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 4, 1000],
["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 4, 1],
["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 4, 1],
["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 4, 7],
["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 4, 0],
["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 4.0, 200,
4, 1],
["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 4, 3],
["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 4, 2],
]


with gr.Blocks() as demo:
gr.Markdown("<h1 style='text-align: center'>Scalable Diffusion Models with Transformers (DiT)</h1>")
gr.Markdown(project_links)
gr.Markdown(description)
gr.Markdown(duplicate)
with gr.Tabs():
with gr.TabItem('Generate'):
with gr.Row():
with gr.Column():
with gr.Row():
image_size = gr.inputs.Radio(choices=["256x256", "512x512"], default="256x256", label='DiT Model Resolution')
vae_model = gr.inputs.Radio(choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"],
default="stabilityai/sd-vae-ft-mse", label='VAE Decoder')
with gr.Row():
i1k_class = gr.inputs.Dropdown(
list(IMAGENET_1K_CLASSES.values()),
default='golden retriever',
type="index", label='ImageNet-1K Class'
)
cfg_scale = gr.inputs.Slider(minimum=1, maximum=25, step=0.1, default=4.0, label='Classifier-free Guidance Scale')
steps = gr.inputs.Slider(minimum=4, maximum=1000, step=1, default=75, label='Sampling Steps')
n = gr.inputs.Slider(minimum=1, maximum=16, step=1, default=1, label='Number of Samples')
seed = gr.inputs.Number(default=0, label='Seed')
button = gr.Button("Generate", variant="primary")
with gr.Column():
output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto")
button.click(generate, inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed], outputs=[output])
with gr.Row():
ex = gr.Examples(examples=examples, fn=generate,
inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed],
outputs=[output],
cache_examples=True)
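
# queue() routes requests through Gradio's queue so long-running GPU generations are
# processed sequentially without hitting request timeouts.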
demo.queue()
demo.launch()