import torch from torchvision.utils import make_grid import math from PIL import Image from diffusion import create_diffusion from diffusers.models import AutoencoderKL import gradio as gr from imagenet_class_data import IMAGENET_1K_CLASSES from download import find_model from models import DiT_XL_2 def load_model(image_size=256): assert image_size in [256, 512] latent_size = image_size // 8 model = DiT_XL_2(input_size=latent_size).to(device) state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt") model.load_state_dict(state_dict) model.eval() return model torch.set_grad_enabled(False) device = "cuda" if torch.cuda.is_available() else "cpu" model = load_model(image_size=256) vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device) current_image_size = 256 current_vae_model = "stabilityai/sd-vae-ft-mse" def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, n, seed): image_size = int(image_size.split("x")[0]) global current_image_size if image_size != current_image_size: global model del model # if device == "cuda": # torch.cuda.empty_cache() model = load_model(image_size=image_size) current_image_size = image_size global current_vae_model if vae_model != current_vae_model: global vae if device == "cuda": vae.to("cpu") del vae vae = AutoencoderKL.from_pretrained(vae_model).to(device) # Seed PyTorch: torch.manual_seed(seed) # Setup diffusion diffusion = create_diffusion(str(num_sampling_steps)) # Create sampling noise: latent_size = image_size // 8 z = torch.randn(n, 4, latent_size, latent_size, device=device) y = torch.tensor([class_label] * n, device=device) # Setup classifier-free guidance: z = torch.cat([z, z], 0) y_null = torch.tensor([1000] * n, device=device) y = torch.cat([y, y_null], 0) model_kwargs = dict(y=y, cfg_scale=cfg_scale) # Sample images: samples = diffusion.p_sample_loop( model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device ) samples, _ = samples.chunk(2, dim=0) # Remove null class samples samples = vae.decode(samples / 0.18215).sample # Convert to PIL.Image format: samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy() samples = [Image.fromarray(sample) for sample in samples] return samples description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.''' duplicate = '''Skip the queue by duplicating this space and upgrading to GPU in settings ''' project_links = '''
Project Page · Colab · Paper · GitHub
''' examples = [ ["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 4, 1000], ["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 4, 1], ["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 4, 1], ["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 4, 7], ["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 4, 0], ["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 4.0, 200, 4, 1], ["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 4, 3], ["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 4, 2], ] with gr.Blocks() as demo: gr.Markdown("