#!/usr/bin/env python3

"""CLIP guided sampling from k-diffusion models."""

import argparse

import accelerate
import clip
from kornia import augmentation as KA
from resize_right import resize
import safetensors.torch as safetorch
import torch
from torch.nn import functional as F
from torchvision import transforms
from tqdm import tqdm

import k_diffusion as K


def spherical_dist_loss(x, y):
    """Squared great-circle distance between unit-normalized embeddings."""
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def make_cond_model_fn(model, cond_fn):
    """Wrap a denoiser so its prediction is shifted by a guidance gradient."""
    def model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            denoised = model(x, sigma, **kwargs)
            cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            # Scale the guidance gradient by sigma**2 so its effect on the
            # denoised prediction tracks the current noise level.
            cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
        return cond_denoised
    return model_fn


def make_static_thresh_model_fn(model, value=1.):
    """Clamp the denoised prediction to [-value, value] (static thresholding)."""
    def model_fn(x, sigma, **kwargs):
        return model(x, sigma, **kwargs).clamp(-value, value)
    return model_fn


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the prompt to use')
    p.add_argument('--batch-size', type=int, default=16,
                   help='the batch size')
    p.add_argument('--checkpoint', type=str, required=True,
                   help='the checkpoint to use')
    p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500.,
                   help='the CLIP guidance scale')
    p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(),
                   help='the CLIP model to use')
    p.add_argument('--config', type=str,
                   help='the model config')
    p.add_argument('-n', type=int, default=64,
                   help='the number of images to sample')
    p.add_argument('--prefix', type=str, default='out',
                   help='the output prefix')
    p.add_argument('--steps', type=int, default=100,
                   help='the number of denoising steps')
    args = p.parse_args()

    config = K.config.load_config(args.config if args.config else args.checkpoint)
    model_config = config['model']
    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    print('Using device:', device, flush=True)

    inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
    inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
    accelerator.print('Parameters:', K.utils.n_params(inner_model))
    model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])

    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']

    clip_model = clip.load(args.clip_model, device=device)[0].eval().requires_grad_(False)
    clip_normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                          std=(0.26862954, 0.26130258, 0.27577711))
    clip_size = (clip_model.visual.input_resolution, clip_model.visual.input_resolution)
    # Random small translations decorrelate the guidance gradient across steps.
    aug = KA.RandomAffine(0, (1/14, 1/14), p=1, padding_mode='border')

    def get_image_embed(x):
        if x.shape[2:4] != clip_size:
            x = resize(x, out_shape=clip_size, pad_mode='reflect')
        x = clip_normalize(x)
        x = clip_model.encode_image(x).float()
        return F.normalize(x)

    target_embed = F.normalize(clip_model.encode_text(clip.tokenize(args.prompt, truncate=True).to(device)).float())

    def cond_fn(x, t, denoised):
        # Map the denoised image from [-1, 1] to [0, 1] before augmenting and embedding.
        image_embed = get_image_embed(aug(denoised.add(1).div(2)))
        loss = spherical_dist_loss(image_embed, target_embed).sum() * args.clip_guidance_scale
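        # Negating the loss gradient w.r.t. x steers sampling toward smaller
        # spherical distance, i.e. toward the target text embedding.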
        grad = -torch.autograd.grad(loss, x)[0]
        return grad

    model_fn = make_cond_model_fn(model, cond_fn)
    model_fn = make_static_thresh_model_fn(model_fn)

    @torch.no_grad()
    @K.utils.eval_mode(model)
    def run():
        if accelerator.is_local_main_process:
            tqdm.write('Sampling...')
        # Karras et al. noise schedule from sigma_max down to sigma_min.
        sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)

        def sample_fn(n):
            x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0]
            x_0 = K.sampling.sample_dpmpp_2s_ancestral(model_fn, x, sigmas, eta=1., disable=not accelerator.is_local_main_process)
            return x_0

        x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
        if accelerator.is_main_process:
            for i, out in enumerate(x_0):
                filename = f'{args.prefix}_{i:05}.png'
                K.utils.to_pil_image(out).save(filename)

    try:
        run()
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
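

# Example invocation (the script, checkpoint, and config names below are
# hypothetical; any trained k-diffusion checkpoint with a matching config
# should work):
#   python clip_guided_sampling.py "a watercolor painting of a lighthouse" \
#       --checkpoint model.safetensors --config config.json \
#       --clip-guidance-scale 500 --steps 100 -n 16 --batch-size 8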