"""Gradio demo: art generation guided by text and/or image prompts.

The image model is a CLOOB-conditioned latent diffusion checkpoint
(``huggan/distill-ccld-wa``) whose UI describes it as trained on WikiArt.

Running this script has heavy side effects, all at import time. It:
  1. clones the ``cloob-latent-diffusion`` code base (if not already present)
     and pip-installs its requirements,
  2. downloads the model checkpoints from the Hugging Face Hub,
  3. loads the autoencoder, diffusion model and CLOOB encoder,
  4. launches the Gradio web UI.
"""
import argparse
import os
import random
import subprocess
import sys
from functools import partial
from pathlib import Path

REPO_URL = 'https://github.com/JD-P/cloob-latent-diffusion'
REPO_DIR = 'cloob-latent-diffusion'

# The third-party imports below come from this repository (and the packages
# it needs), so it must be fetched and installed before they run.
if not os.path.isdir(REPO_DIR):
    # Skip the clone on re-runs; "git clone" into an existing dir just fails.
    subprocess.run(['git', 'clone', '--recursive', REPO_URL], check=True)
# Best effort, like the original: a failed install surfaces as an ImportError
# below. Use this interpreter's pip so packages land in the running env.
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'omegaconf', 'pillow',
     'pytorch-lightning', 'einops', 'wandb', 'ftfy', 'regex', './CLIP'],
    cwd=REPO_DIR,
    check=False,
)

for _import_root in (
    './cloob-latent-diffusion',
    './cloob-latent-diffusion/cloob-training',
    './cloob-latent-diffusion/latent-diffusion',
    './cloob-latent-diffusion/taming-transformers',
    './cloob-latent-diffusion/v-diffusion-pytorch',
):
    sys.path.append(_import_root)

from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange

from CLIP import clip
from cloob_training import model_pt, pretrained
import ldm.models.autoencoder
from diffusion import sampling, utils
import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, hf_hub_download
import gradio as gr

# Download the model files.
# (hf_hub_download replaces the deprecated, now-removed cached_download.)
diffusion_checkpoint_path = hf_hub_download(
    repo_id="huggan/distill-ccld-wa", filename="model_student.ckpt")
ae_model_path = hf_hub_download(repo_id="huggan/ccld_wa", filename="ae_model.ckpt")
ae_config_path = hf_hub_download(repo_id="huggan/ccld_wa", filename="ae_model.yaml")

# Sampler names accepted by generate(); each maps to a function in
# diffusion.sampling (see generate's inner run()).
SAMPLING_METHODS = ('ddpm', 'ddim', 'prk', 'plms', 'pie', 'plms2')


# Define a few utility functions

def parse_prompt(prompt, default_weight=3.):
    """Split a ``"text:weight"`` prompt into ``(text, float(weight))``.

    The weight suffix is optional; without it ``default_weight`` is used.
    For ``http://`` / ``https://`` prompts the scheme's colon is kept as part
    of the URL. If the text after the last colon is not a number (e.g.
    ``"Still life: apples"``, ``"C:\\img.png"``, a URL with a port), the
    whole prompt is treated as text with the default weight instead of
    raising ``ValueError``.
    """
    if prompt.startswith(('http://', 'https://')):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    if len(vals) == 1:
        return vals[0], float(default_weight)
    try:
        return vals[0], float(vals[1])
    except ValueError:
        # The suffix is part of the prompt itself, not a weight.
        return prompt, float(default_weight)


def resize_and_center_crop(image, size):
    """Scale a PIL ``image`` to cover ``size`` = (width, height), then center-crop.

    The scale factor is the larger of the two axis ratios, so the resized
    image is at least ``size`` on both axes; ``TF.center_crop`` takes
    (height, width), hence ``size[::-1]``.
    """
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])


# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')

# NOTE(review): torch.load unpickles the checkpoint files; they come from the
# fixed Hub repos above, so treat those repos as trusted code.

# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
# map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
# load_state_dict then copies the weights onto the model's device.
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))
# Latent shape sampled by generate(): (channels, height, width).
n_ch, side_y, side_x = 4, 32, 32

# diffusion model (architecture arguments must match the checkpoint)
model = train.DiffusionModel(192, [1, 1, 2, 2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(diffusion_checkpoint_path, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB (text/image encoders used for conditioning)
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
cloob_checkpoint_path = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, cloob_checkpoint_path))
cloob.eval().requires_grad_(False).to(device)


# The key function: returns a list of n PIL images
def generate(n=1, prompts=None, images=None, seed=42, steps=15, method='plms', eta=None):
    """Sample ``n`` images guided by text and/or image prompts.

    Args:
        n: number of images to generate (0 returns an empty list).
        prompts: text prompts, each optionally suffixed with ``:weight``
            (see ``parse_prompt``). Defaults to ``['a red circle']``.
        images: image paths or URLs, each optionally suffixed with
            ``:weight``. Defaults to no image prompts.
        seed: seed for the initial latent noise.
        steps: number of sampling steps.
        method: sampler name, one of ``SAMPLING_METHODS``.
        eta: noise amount for the ``'ddim'`` sampler; ``None`` means 0.0.
            Ignored by the other samplers.

    Returns:
        A list of ``n`` PIL images.

    Raises:
        ValueError: if ``method`` is not one of ``SAMPLING_METHODS``.
    """
    if method not in SAMPLING_METHODS:
        raise ValueError(f'unknown sampling method {method!r}; expected one of {SAMPLING_METHODS}')
    if prompts is None:
        prompts = ['a red circle']
    if images is None:
        images = []
    ddim_eta = 0. if eta is None else eta

    # The zero embedding is the unconditional branch; its weight is chosen
    # below so that all guidance weights sum to 1.
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    clip_size = cloob.config['image_encoder']['image_size']
    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    weights = torch.tensor([1 - sum(weights), *weights], device=device)

    torch.manual_seed(seed)

    def cfg_model_fn(x, t):
        """Guided model: evaluate every conditioning in one batch and blend the outputs."""
        batch_len = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        clip_embed_in = torch.cat(target_embeds).repeat_interleave(batch_len, 0)
        vs = model(x_in, t_in, clip_embed_in).view([n_conds, batch_len, *x.shape[1:]])
        return vs.mul(weights[:, None, None, None, None]).sum(0)

    def run(x, schedule):
        """Sample from noise ``x`` along ``schedule`` with the chosen sampler."""
        # An if-chain (not a dict) so only the selected sampler attribute is
        # looked up on the sampling module.
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, schedule, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, schedule, ddim_eta, {})
        if method == 'prk':
            return sampling.prk_sample(cfg_model_fn, x, schedule, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, schedule, {})
        if method == 'pie':
            return sampling.pie_sample(cfg_model_fn, x, schedule, {})
        if method == 'plms2':
            return sampling.plms2_sample(cfg_model_fn, x, schedule, {})
        raise ValueError(f'unknown sampling method {method!r}')  # unreachable: validated above

    # Everything is sampled in a single batch; max() keeps the range step
    # non-zero when n == 0.
    batch_size = max(n, 1)
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    schedule = utils.get_spliced_ddpm_cosine_schedule(t)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], schedule)
        # NOTE(review): 2.55 differs from the autoencoder_scale (4.3084) given
        # to DiffusionModel; confirm this decode scaling against the training code.
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        pil_ims.extend(utils.to_pil_image(out) for out in outs)
    return pil_ims


def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """Gradio callback: generate one image from a text and optional image prompt.

    A random seed in [0, 10000] is drawn when ``seed`` is None.
    """
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    prompts = [prompt]
    im_prompts = [] if im_prompt is None else [im_prompt]
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
    return pil_ims[0]


EXAMPLES = [
    ["Virgin and Child, in the style of Jacopo Bellini"],
    ["Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji"],
    ["Moon Light Sonata by Basuki Abdullah"],
    ["Twon Tree by M.C. Escher"],
    ["Futurism, in the style of Wassily Kandinsky"],
    ["Art Nouveau, in the style of John Singer Sargent"],
    ["Surrealism, in the style of Edgar Degas"],
    ["Expressionism, in the style of Wassily Kandinsky"],
    ["Futurism, in the style of Egon Schiele"],
    ["Neoclassicism, in the style of Gustav Klimt"],
    ["Cubism, in the style of Gustav Klimt"],
    ["Op Art, in the style of Marc Chagall"],
    ["Romanticism, in the style of M.C. Escher"],
    ["Futurism, in the style of M.C. Escher"],
    ["Abstract Art, in the style of M.C. Escher"],
    ["Mannerism, in the style of Paul Klee"],
    ["Romanesque Art, in the style of Leonardo da Vinci"],
    ["High Renaissance, in the style of Rembrandt"],
    ["Magic Realism, in the style of Gustave Dore"],
    ["Realism, in the style of Jean-Michel Basquiat"],
    ["Art Nouveau, in the style of Paul Gauguin"],
    ["Avant-garde, in the style of Pierre-Auguste Renoir"],
    ["Baroque, in the style of Edward Hopper"],
    ["Post-Impressionism, in the style of Wassily Kandinsky"],
    ["Naturalism, in the style of Rene Magritte"],
    ["Constructivism, in the style of Paul Cezanne"],
    ["Abstract Expressionism, in the style of Henri Matisse"],
    ["Pop Art, in the style of Vincent van Gogh"],
    ["Futurism, in the style of Wassily Kandinsky"],
    ["Futurism, in the style of Zdzislaw Beksinski"],
    ['Surrealism, in the style of Salvador Dali'],
    ["Aaron Wacker, oil on canvas"],
    ["abstract"],
    ["landscape"],
    ["portrait"],
    ["sculpture"],
    ["genre painting"],
    ["installation"],
    ["photo"],
    ["figurative"],
    ["illustration"],
    ["still life"],
    ["history painting"],
    ["cityscape"],
    ["marina"],
    ["animal painting"],
    ["design"],
    ["calligraphy"],
    ["symbolic painting"],
    ["graffiti"],
    ["performance"],
    ["mythological painting"],
    ["battle painting"],
    ["self-portrait"],
    ["Impressionism, oil on canvas"],
]

# NOTE(review): gr.inputs / gr.outputs / enable_queue are the legacy Gradio
# API; this file assumes a Gradio version that still provides them.
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        # gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1,label="Number of images"),
        # gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
        gr.inputs.Textbox(label="Text prompt"),
        gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
        # gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15,label="Number of steps")
    ],
    outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
    examples=EXAMPLES,
    title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia:',
    description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
    article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa)..',
)
iface.launch(enable_queue=True)  # , debug=True for colab debugging