File size: 7,433 Bytes
4fbd61f
 
 
fee4293
4fbd61f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01d0d8e
4fbd61f
 
 
01d0d8e
 
 
4fbd61f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abfd4d8
4fbd61f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
import shlex
import sys

# Fetch cloob-latent-diffusion (with its git submodules) and install the
# packages it needs. Both steps are best-effort: os.system's exit status is
# ignored, so e.g. a clone into an already-existing directory does not stop
# the script.
os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
# `&&` so pip only runs when the cd succeeded (otherwise ./CLIP would resolve
# against the wrong directory), and `sys.executable -m pip` so packages are
# installed for the interpreter that is running this script.
os.system(
    "cd cloob-latent-diffusion && "
    f"{shlex.quote(sys.executable)} -m pip install omegaconf pillow "
    "pytorch-lightning==1.6.5 einops wandb ftfy regex ./CLIP"
)

import argparse
from functools import partial
from pathlib import Path
import sys
sys.path.append('./cloob-latent-diffusion')
sys.path.append('./cloob-latent-diffusion/cloob-training')
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
sys.path.append('./cloob-latent-diffusion/taming-transformers')
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from CLIP import clip
from cloob_training import model_pt, pretrained
import ldm.models.autoencoder
from diffusion import sampling, utils
import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, hf_hub_download
import random

# Download the model files from the Hugging Face Hub. hf_hub_download takes
# the repository id and the file name and returns the local path of the
# cached copy of that file.
checkpoint = hf_hub_download("huggan/distill-ccld-wa", filename="model_student.ckpt")
ae_model_path = hf_hub_download("huggan/ccld_wa", filename="ae_model.ckpt")
ae_config_path = hf_hub_download("huggan/ccld_wa", filename="ae_model.yaml")

# Define a few utility functions

def parse_prompt(prompt, default_weight=3.):
    """Split a ``"text:weight"`` prompt into its text and numeric weight.

    The weight is whatever follows the last ``:``, provided it parses as a
    finite number. Otherwise the whole string is the prompt and
    ``default_weight`` is used. This keeps the ``:`` in URLs
    (``https://...``, ``host:port``) and in ordinary text
    (``"Title: subtitle"``) from being mistaken for a weight separator.

    Args:
        prompt: Text, or an image path/URL, optionally suffixed with ``:weight``.
        default_weight: Weight used when no valid ``:weight`` suffix is present.

    Returns:
        A ``(prompt, weight)`` tuple with ``weight`` as a float.
    """
    import math

    head, sep, tail = prompt.rpartition(':')
    if sep:
        try:
            weight = float(tail)
        except ValueError:
            pass  # Suffix is not a number: the colon belongs to the prompt.
        else:
            # "inf"/"nan" parse as floats but are never meaningful weights.
            if math.isfinite(weight):
                return head, weight
    return prompt, float(default_weight)


def resize_and_center_crop(image, size):
    """Scale ``image`` so it covers ``size``, then center-crop to exactly ``size``.

    Args:
        image: A PIL image.
        size: Target ``(width, height)`` in pixels.

    Returns:
        The center-cropped image produced by torchvision's ``center_crop``.
    """
    # Smallest scale factor at which both dimensions reach the target.
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    # round() rather than int(): fac * dim can come out a hair below the
    # target (e.g. 223.999... instead of 224), and truncating would leave the
    # image one pixel short so center_crop would pad instead of crop.
    new_size = (round(fac * image.size[0]), round(fac * image.size[1]))
    image = image.resize(new_size, Image.LANCZOS)
    # PIL sizes are (width, height); torchvision's center_crop wants (height, width).
    return TF.center_crop(image, size[::-1])


# Load the models
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models')
# autoencoder
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
# map_location='cpu' (as for the diffusion checkpoint below) lets a checkpoint
# saved from GPU tensors load on a CPU-only host; load_state_dict then copies
# the weights into the module's parameters on `device`.
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))
# Shape (channels, height, width) of the latent noise sampled by generate().
n_ch, side_y, side_x = 4, 32, 32

# diffusion model (distilled student checkpoint)
# NOTE(review): these constructor arguments must match the architecture the
# checkpoint was trained with, or load_state_dict will reject it.
model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB text/image encoder used to embed the prompts
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
# Separate name so `checkpoint` keeps referring to the diffusion checkpoint.
cloob_checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, cloob_checkpoint))
cloob.eval().requires_grad_(False).to(device)


# The key function: returns a list of n PIL images
def generate(n=1, prompts=None, images=None, seed=42, steps=15,
             method='plms', eta=None):
    """Sample ``n`` images conditioned on text and/or image prompts.

    Each prompt may carry a ``:weight`` suffix (see ``parse_prompt``). At
    every sampling step the model's output for each prompt embedding is
    combined with its output for an all-zero embedding. The zero embedding
    receives weight ``1 - sum(prompt weights)``.

    Args:
        n: Number of images to generate; ``n <= 0`` returns an empty list.
        prompts: Text prompts. ``None`` means ``['a red circle']``.
        images: Image prompts as local paths or http(s) URLs. ``None`` means
            no image prompts.
        seed: Passed to ``torch.manual_seed`` before the initial noise is drawn.
        steps: Number of sampling steps.
        method: Sampler name: 'ddpm', 'ddim', 'prk', 'plms', 'pie' or 'plms2'.
        eta: Noise level for the 'ddim' sampler only; ``None`` is treated as 0.

    Returns:
        A list of ``n`` PIL images.

    Raises:
        ValueError: If ``method`` is not a known sampler name.
    """
    # None defaults instead of shared mutable list defaults.
    if prompts is None:
        prompts = ['a red circle']
    if images is None:
        images = []

    # Sampler dispatch. 'ddpm' calls the same sampler as 'ddim' with eta fixed
    # to 1. sampling.sample presumably needs a numeric eta, so None maps to 0.
    ddim_eta = 0. if eta is None else eta
    samplers = {
        'ddpm': lambda fn, x, ts: sampling.sample(fn, x, ts, 1., {}),
        'ddim': lambda fn, x, ts: sampling.sample(fn, x, ts, ddim_eta, {}),
        'prk': lambda fn, x, ts: sampling.prk_sample(fn, x, ts, {}),
        'plms': lambda fn, x, ts: sampling.plms_sample(fn, x, ts, {}),
        'pie': lambda fn, x, ts: sampling.pie_sample(fn, x, ts, {}),
        'plms2': lambda fn, x, ts: sampling.plms2_sample(fn, x, ts, {}),
    }
    # Validate up front with a real exception (an assert is stripped under -O)
    # so a bad name fails before any encoder or model work is done.
    if method not in samplers:
        raise ValueError(
            f"unknown sampling method {method!r}; "
            f"expected one of {sorted(samplers)}")
    if n <= 0:
        return []

    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    for prompt in images:
        path, weight = parse_prompt(prompt)
        # Close the handle returned by utils.fetch once the pixels are decoded;
        # convert() loads the image data, so the file is no longer needed.
        with utils.fetch(path) as fp:
            img = Image.open(fp).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    # The all-zero embedding (index 0) takes the remaining weight, so the
    # weights sum to 1.
    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    cond_embeds = torch.cat(target_embeds)
    n_conds = len(target_embeds)

    torch.manual_seed(seed)

    def cfg_model_fn(x, t):
        # Evaluate the model for every conditioning embedding in one batched
        # call, then take the weighted sum of the per-embedding outputs.
        batch_n = x.shape[0]
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        embed_in = cond_embeds.repeat_interleave(batch_n, 0)
        vs = model(x_in, t_in, embed_in).view([n_conds, batch_n, *x.shape[1:]])
        return vs.mul(weights[:, None, None, None, None]).sum(0)

    sample = samplers[method]
    batch_size = n
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    schedule = utils.get_spliced_ddpm_cosine_schedule(t)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = sample(cfg_model_fn, x[i:i + cur_batch_size], schedule)
        # NOTE(review): this 2.55 rescale differs from the autoencoder_scale
        # (4.3084) the diffusion model is constructed with; confirm intended.
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        pil_ims.extend(utils.to_pil_image(out) for out in outs)

    return pil_ims
  
  
import gradio as gr

def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """Gradio callback: generate a single image from the UI inputs.

    Args:
        prompt: Text prompt from the textbox.
        im_prompt: Optional image-prompt file path; ``None`` or an empty
            string means no image prompt.
        seed: Sampling seed; when ``None`` a random one in [0, 10000] is drawn.
        n_steps: Number of sampling steps.
        method: Sampler name forwarded to ``generate``.

    Returns:
        The generated PIL image.
    """
    if seed is None:
        seed = random.randint(0, 10000)
    print(prompt, im_prompt, seed, n_steps)
    # A missing or empty upload is skipped rather than handed to the image loader.
    im_prompts = [im_prompt] if im_prompt else []
    pil_ims = generate(n=1, prompts=[prompt], images=im_prompts, seed=seed,
                       steps=n_steps, method=method)
    return pil_ims[0]

# Gradio UI: a text prompt plus an optional image prompt in, one image out.
# Image count, seed and step count are not exposed; gen_ims supplies
# defaults for those.
_demo_inputs = [
    gr.inputs.Textbox(label="Text prompt"),
    gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
]
_demo_outputs = [gr.outputs.Image(type="pil", label="Generated Image")]
_demo_examples = [
    ["An iceberg, oil on canvas"],
    ["A martian landscape, in the style of Monet"],
    ['A peaceful meadow, pastel crayons'],
    ["A painting of a vase of flowers"],
    ["A ship leaving the port in the summer, oil on canvas"],
]
_demo_description = "By typing a text prompt or providing an image prompt, and pressing submit you can generate images based on this prompt. The model was trained on images from the [WikiArt](https://huggingface.co/datasets/huggan/wikiart) dataset, comprised mostly of paintings."
_demo_article = 'The model is a distilled version of a cloob-conditioned latent diffusion model fine-tuned on the WikiArt dataset. You can find more information on this model on the [model card](https://huggingface.co/huggan/distill-ccld-wa). The student model training and this demo were done by [@gigant](https://huggingface.co/gigant). The teacher model was trained by [@johnowhitaker](https://huggingface.co/johnowhitaker)'

iface = gr.Interface(
    fn=gen_ims,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    examples=_demo_examples,
    title='Generate art from text prompts :',
    description=_demo_description,
    article=_demo_article,
)
# Pass debug=True as well when debugging in Colab.
iface.launch(enable_queue=True)