#!/usr/bin/env python3
"""Sample an image from the unconditional latent-diffusion checkpoint in ./.

The same seeded generation is performed two ways:

1. by unrolling the DDIM sampling loop by hand (UNet + DDIM scheduler +
   VQ-VAE decode), and
2. by calling ``LatentDiffusionUncondPipeline`` on the same checkpoint.

Each result is written to its own PNG file in the working directory.
"""
import numpy as np
import PIL.Image
import torch
import tqdm
from diffusers import (
    DDIMScheduler,
    LatentDiffusionUncondPipeline,
    UNet2DModel,
    VQModel,
)

SEED = 3
NUM_INFERENCE_STEPS = 200
MODEL_DIR = "./"


def _to_pil(image: torch.Tensor) -> PIL.Image.Image:
    """Convert the first image of a batch with values in [-1, 1] to PIL.

    Assumes ``image`` is laid out as (batch, channels, height, width); the
    permute below relies on that. Values outside [-1, 1] are clamped.
    """
    # Channels-first -> channels-last, the layout PIL.Image.fromarray expects.
    image = image.cpu().permute(0, 2, 3, 1)
    # Map [-1, 1] onto [0, 255] before quantizing to uint8.
    image = (image + 1.0) * 127.5
    pixels = image.clamp(0, 255).numpy().astype(np.uint8)
    return PIL.Image.fromarray(pixels[0])


def sample_unrolled(
    seed: int, num_inference_steps: int = NUM_INFERENCE_STEPS
) -> PIL.Image.Image:
    """Generate one image by running the DDIM denoising loop manually.

    Args:
        seed: Seed for the initial Gaussian latents.
        num_inference_steps: Number of DDIM steps to run.

    Returns:
        The decoded image as a PIL image.
    """
    # Load the individual components from their checkpoint subfolders.
    unet = UNet2DModel.from_pretrained(MODEL_DIR, subfolder="unet")
    vqvae = VQModel.from_pretrained(MODEL_DIR, subfolder="vqvae")
    scheduler = DDIMScheduler.from_config(MODEL_DIR, subfolder="scheduler")

    # Use the GPU when one is available, otherwise fall back to the CPU.
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    unet.to(torch_device)
    vqvae.to(torch_device)

    # Draw the starting noise on the CPU with a seeded generator, then move
    # it, so the initial latents do not depend on which device is used.
    generator = torch.manual_seed(seed)
    sample = torch.randn(
        (1, unet.in_channels, unet.image_size, unet.image_size),
        generator=generator,
    ).to(torch_device)

    scheduler.set_timesteps(num_inference_steps=num_inference_steps)

    with torch.no_grad():
        for t in tqdm.tqdm(scheduler.timesteps):
            # Predict the noise residual contained in the current sample x_t.
            residual = unet(sample, t)["sample"]
            # One DDIM update x_t -> x_{t-1} (eta=0.0); the result is the
            # input to the next iteration.
            sample = scheduler.step(residual, t, sample, eta=0.0)["prev_sample"]

        # Decode the final latents with the VQ-VAE.
        decoded = vqvae.decode(sample)

    return _to_pil(decoded)


def sample_with_pipeline(
    seed: int, num_inference_steps: int = NUM_INFERENCE_STEPS
) -> PIL.Image.Image:
    """Generate one image with ``LatentDiffusionUncondPipeline``.

    Args:
        seed: Seed passed (via ``torch.manual_seed``) to the pipeline.
        num_inference_steps: Number of denoising steps for the pipeline.

    Returns:
        The generated image as a PIL image.
    """
    pipeline = LatentDiffusionUncondPipeline.from_pretrained(MODEL_DIR)
    # NOTE(review): the pipeline is not moved to CUDA here, unlike the
    # unrolled path; it runs on whatever device it loads onto — confirm.
    generator = torch.manual_seed(seed)
    image = pipeline(generator=generator, num_inference_steps=num_inference_steps)[
        "sample"
    ]
    return _to_pil(image)


def main() -> None:
    """Run both sampling approaches and save each result as a PNG."""
    # 1. Unroll the full loop
    unrolled_pil = sample_unrolled(SEED)
    unrolled_pil.save(f"generated_image_{SEED}_unrolled.png")

    # 2. Use pipeline
    image_pil = sample_with_pipeline(SEED)
    image_pil.save(f"generated_image_{SEED}.png")


if __name__ == "__main__":
    main()