"""Generate images from a text prompt with a latent-diffusion checkpoint.

Each sample is written as a PNG into ``<outdir>/samples`` and all samples are
additionally written as a single grid image into ``<outdir>``.
"""
import argparse
import os
import sys
import glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import pdb
import random


def load_model_from_config(config, ckpt, verbose=False, device=None):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Configuration object; its ``model`` entry is handed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file that contains a ``"state_dict"`` entry.
        verbose: When true, print the missing / unexpected state-dict keys.
        device: Device to move the model to. When ``None`` (the default) CUDA
            is used if available, otherwise the CPU.

    Returns:
        The instantiated model, moved to ``device`` and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints that come from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are reported (when verbose) instead of raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing and verbose:
        print("missing keys:")
        print(missing)
    if unexpected and verbose:
        print("unexpected keys:")
        print(unexpected)

    # Fall back to the CPU instead of crashing on hosts without CUDA.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model


def build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=256,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=256,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=4,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    # Default is None so that an explicitly supplied seed can be told apart
    # from "no seed given" (in which case a random one is drawn, see main()).
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="the seed (for reproducible sampling); a random seed in [1, 500] is used when omitted",
    )
    # NOTE(review): --ema and --embedding_path are parsed but never read by
    # this script; sampling always runs inside model.ema_scope().
    parser.add_argument(
        "--ema",
        action='store_true',
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
        help="Path to pretrained ldm text2img model")
    parser.add_argument(
        "--embedding_path",
        type=str,
        help="Path to a pre-trained embedding manager checkpoint")
    return parser


def main():
    """Parse the CLI options, sample images for the prompt and save them."""
    opt = build_parser().parse_args()

    # Honour --seed when it is given; otherwise keep drawing a random seed.
    seed = opt.seed if opt.seed is not None else random.randint(1, 500)
    seed_everything(seed)

    # config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml")
    # TODO: Optionally download from same location as ckpt and change this logic
    config = OmegaConf.load("configs/finetune/finetune_generic.yaml")
    # config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-finetune2.yaml")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = load_model_from_config(config, opt.ckpt_path, device=device)  # TODO: check path
    # model = load_model_from_config(config, "ckpt/model.ckpt")
    # model.cond_stage_model.load_state_dict(checkpoint['model_state_dict'])

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model, device=device)

    os.makedirs(opt.outdir, exist_ok=True)
    sample_path = os.path.join(opt.outdir, "samples")
    os.makedirs(sample_path, exist_ok=True)
    # Continue numbering after whatever already exists in the samples directory.
    base_count = len(os.listdir(sample_path))

    all_samples = []
    with torch.no_grad(), model.ema_scope():
        # The empty-prompt conditioning is only needed when guidance is active.
        uc = None
        if opt.scale != 1.0:
            uc = model.get_learned_conditioning(opt.n_samples * [""])

        for _ in trange(opt.n_iter, desc="Sampling"):
            c = model.get_learned_conditioning(opt.n_samples * [opt.prompt])
            # Sampling shape: spatial size is the pixel size divided by 8.
            shape = [4, opt.H // 8, opt.W // 8]
            samples_ddim, _intermediates = sampler.sample(
                S=opt.ddim_steps,
                conditioning=c,
                batch_size=opt.n_samples,
                shape=shape,
                verbose=False,
                unconditional_guidance_scale=opt.scale,
                unconditional_conditioning=uc,
                eta=opt.ddim_eta,
            )

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Rescale with (x + 1) / 2 and clamp into [0, 1] before converting to uint8.
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples_ddim:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                Image.fromarray(x_sample.astype(np.uint8)).save(
                    os.path.join(sample_path, f"{base_count:04}.png"))
                base_count += 1
            all_samples.append(x_samples_ddim)

    # additionally, save as grid (skipped when nothing was sampled, e.g.
    # --n_iter 0, because torch.stack rejects an empty list)
    if all_samples:
        grid = torch.stack(all_samples, 0)
        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
        grid = make_grid(grid, nrow=opt.n_samples)

        # to image
        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
        Image.fromarray(grid.astype(np.uint8)).save(
            os.path.join(opt.outdir, f'{opt.prompt.replace(" ", "-")}.jpg'))

    print(f"Your samples are ready and waiting four you here: \n{opt.outdir} \nEnjoy.")


if __name__ == "__main__":
    main()