File size: 8,468 Bytes
bc5a411
077fc3d
 
 
f955f91
bc5a411
cff8aa8
 
ab9e9c4
bc5a411
077fc3d
51df617
3668992
 
 
 
 
 
 
 
 
077fc3d
 
bc5a411
 
077fc3d
449a298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6f9b71
 
 
 
 
bc5a411
612ce40
bc5a411
 
 
 
568d1c7
612ce40
 
 
568d1c7
eecb1f6
ab9e9c4
 
 
 
6144b88
 
 
3a72088
6144b88
 
077fc3d
bc5a411
7af4a09
bf89172
 
6144b88
be31516
 
bf89172
5e6effb
bf89172
 
 
 
 
 
 
 
 
 
 
918aa0f
bf89172
 
 
918aa0f
bf89172
 
 
 
 
 
bc5a411
 
 
 
 
 
 
 
 
6144b88
3a72088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cd412c
3a72088
be31516
3a72088
 
 
bc5a411
 
 
e3d2366
 
 
 
3a72088
 
 
 
 
612ce40
e3d2366
bc5a411
 
 
 
077fc3d
 
 
26ca94f
077fc3d
 
 
2b1e8e5
3a72088
 
 
6144b88
077fc3d
bc5a411
be31516
 
077fc3d
aa344ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gc
import math
import sys

#from IPython import display
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import utils as tv_utils
from torchvision.transforms import functional as TF
import gradio as gr
from git.repo.base import Repo
from os.path import exists as path_exists

# Clone the upstream repositories this app imports from, skipping any that are
# already present in the working directory, then put v-diffusion-pytorch on the
# import path so the `from diffusion import ...` further down can resolve.
for _repo_url, _repo_dir in (
    ("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch"),
    ("https://github.com/openai/CLIP", "CLIP"),
):
    if not path_exists(_repo_dir):
        Repo.clone_from(_repo_url, _repo_dir)
del _repo_url, _repo_dir
sys.path.append('v-diffusion-pytorch')

from huggingface_hub import hf_hub_download

from CLIP import clip
from diffusion import get_model, sampling, utils

class MakeCutouts(nn.Module):
    """Take ``cutn`` random square crops of an image batch, each pooled to ``cut_size``.

    Crop side lengths are drawn between ``min(H, W, cut_size)`` and
    ``min(H, W)``. The uniform draw is raised to ``cut_pow``, so values above 1
    favour smaller crops. The crops of the whole batch are concatenated crop by
    crop along dim 0, giving ``cutn * N`` images.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def _random_square(self, images, height, width, smallest, span):
        """Cut one random square (same window for every batch item) and resize it."""
        side = int(torch.rand([]) ** self.cut_pow * span + smallest)
        left = torch.randint(0, width - side + 1, ())
        top = torch.randint(0, height - side + 1, ())
        window = images[:, :, top:top + side, left:left + side]
        return F.adaptive_avg_pool2d(window, self.cut_size)

    def forward(self, input):
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        span = largest - smallest
        return torch.cat([
            self._random_square(input, height, width, smallest, span)
            for _ in range(self.cutn)
        ])

def spherical_dist_loss(x, y):
    """Return ``2 * arcsin(|x_hat - y_hat| / 2) ** 2`` along the last dimension.

    Both inputs are L2-normalised over their last dimension first and are
    broadcast against each other. The last dimension is reduced away.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2 * torch.arcsin(half_chord) ** 2
    
# --- Model setup: runs once, at import time -----------------------------------
# Fetch the CC12M classifier-free-guidance checkpoint from the Hugging Face Hub;
# hf_hub_download returns the local path of the downloaded file.
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
#cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
model = get_model('cc12m_1_cfg')()
# side_y / side_x are the spatial size of the noise tensor run_all samples from.
_, side_y, side_x = model.shape
# NOTE(review): torch.load unpickles the checkpoint; acceptable only because it comes
# from a fixed Hub repo. Consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
# fp16 on the GPU, eval mode, frozen weights (CLIP guidance in run_all differentiates
# w.r.t. the image input, not the parameters).
model = model.half().cuda().eval().requires_grad_(False)

# Disabled: optional non-CFG cc12m_1 model.
# NOTE(review): if re-enabled, the load_state_dict line should read cc12m_small_model,
# not cc12m_model (which is the CFG checkpoint downloaded above).
#model_small = get_model('cc12m_1')()
#model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
#model_small = model_small.half().cuda().eval().requires_grad_(False)

# Load the CLIP variant the diffusion model was trained against (model.clip_model).
# [0] keeps only the model; clip.load presumably also returns a preprocess transform.
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
clip_model.eval().requires_grad_(False)
# Per-channel mean/std applied to image cutouts before clip_model.encode_image.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
# 16 cutouts per image at CLIP's input resolution; this count must stay in sync
# with how run_all reshapes the CLIP image embeddings of the cutouts.
make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
# Release temporaries left over from loading before the first request arrives.
gc.collect()
torch.cuda.empty_cache()

def parse_prompt(prompt):
    """Split a ``"text:weight"`` prompt into ``(text, weight)``.

    The weight is the suffix after the last ``:``. For ``http(s)://`` URLs the
    scheme's own colon is not treated as a separator. If there is no suffix, or
    the suffix is not a number (e.g. ``"style: oil on canvas"``), the whole
    string is returned as the text with weight ``1.0`` instead of raising.

    Args:
        prompt: Raw prompt string.

    Returns:
        ``(text, weight)`` with ``weight`` as a float.
    """
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    if len(vals) == 1:
        return vals[0], 1.0
    try:
        return vals[0], float(vals[1])
    except ValueError:
        # Free-form text that merely contains a colon: keep it verbatim.
        return prompt, 1.0


def _encode_weighted_prompts(prompts):
    """Return the L2-normalised, weight-averaged CLIP text embedding of *prompts*.

    Each prompt is split with ``parse_prompt``. The result keeps a leading
    dimension of 1 (``sum(0, keepdim=True)``).

    Raises:
        RuntimeError: if the parsed weights sum to (almost) zero.
    """
    embeds, weights = [], []
    for raw_prompt in prompts:
        txt, prompt_weight = parse_prompt(raw_prompt)
        embeds.append(clip_model.encode_text(clip.tokenize(txt).to('cuda')).float())
        weights.append(prompt_weight)
    embeds = torch.cat(embeds)
    weights = torch.tensor(weights, device='cuda')
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()
    return F.normalize(embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)


def run_all(prompt, steps, n_images, weight, clip_guided):
    """Generate images for *prompt* with the CC12M v-diffusion model (PLMS sampling).

    Args:
        prompt: Text prompt. In CLIP-guided mode a trailing ``:number`` is
            parsed as a prompt weight (see ``parse_prompt``).
        steps: Number of sampling steps; multiplied by 5 in CLIP-guided mode.
        n_images: Number of images sampled in parallel; forced to 1 in
            CLIP-guided mode.
        weight: Classifier-free guidance scale. In CLIP-guided mode it is
            instead multiplied by 100 to give the CLIP guidance scale.
        clip_guided: If true, steer sampling with gradients of a CLIP loss on
            random cutouts instead of classifier-free guidance.

    Returns:
        A list with one image per sample, converted by ``utils.to_pil_image``.

    Raises:
        RuntimeError: in CLIP-guided mode, if the parsed prompt weights sum to 0.
    """
    import random

    gc.collect()
    torch.cuda.empty_cache()
    # Gradio sliders may deliver floats; torch.linspace / torch.randn need ints.
    steps = int(steps)
    n_images = int(n_images)
    seed = random.randint(0, 2147483647)
    target_embed = clip_model.encode_text(clip.tokenize(prompt).to('cuda')).float()

    clip_guidance_scale = None
    clip_embed = None
    if clip_guided:
        n_images = 1
        steps = steps * 5
        clip_guidance_scale = weight * 100
        # Use the parsed, weighted embedding (not the raw target_embed) so that a
        # ':weight' suffix is stripped from the text before it is encoded.
        clip_embed = _encode_weighted_prompts([prompt]).repeat([n_images, 1])

    def cfg_model_fn(x, t):
        """Classifier-free guidance: extrapolate from unconditional towards conditional v."""
        n = x.shape[0]
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed_repeat = target_embed.repeat([n, 1])
        # First half: zero (unconditional) embedding; second half: the prompt.
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        return v_uncond + (v_cond - v_uncond) * weight

    def make_cond_model_fn(model, cond_fn):
        """Wrap *model* so its v prediction is shifted by the gradient from *cond_fn*."""
        def cond_model_fn(x, t, **extra_args):
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                v = model(x, t, **extra_args)
                alphas, sigmas = utils.t_to_alpha_sigma(t)
                # Estimate of the denoised image, scored by cond_fn.
                pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
                cond_grad = cond_fn(x, t, pred, **extra_args).detach()
                v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
            return v
        return cond_model_fn

    def cond_fn(x, t, pred, clip_embed):
        """Return the negative gradient w.r.t. ``x`` of the scaled CLIP distance loss."""
        if min(pred.shape[2:4]) < 256:
            pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        # MakeCutouts concatenates cutout by cutout, so regroup as [cutn, batch, dim].
        image_embeds = clip_model.encode_image(clip_in).view([make_cutouts.cutn, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * clip_guidance_scale
        return -torch.autograd.grad(loss, x)[0]

    torch.manual_seed(seed)
    x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
    t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
    if model.min_t == 0:
        step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    else:
        step_list = utils.get_ddpm_schedule(t)
    if clip_guided:
        model_fn = make_cond_model_fn(model, cond_fn)
        outs = sampling.plms_sample(model_fn, x, step_list, {'clip_embed': clip_embed})
    else:
        outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})
    return [utils.to_pil_image(out) for out in outs]
    

##################### START GRADIO HERE ############################
# Output: a carousel whose items are images (run_all returns a list of them).
gallery = gr.outputs.Carousel(label="Individual images", components=["image"])

# Input widgets, in the positional order of
# run_all(prompt, steps, n_images, weight, clip_guided).
ui_inputs = [
    gr.inputs.Textbox(
        label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",
        default="an eerie alien forest",
    ),
    gr.inputs.Slider(
        label="Steps - more steps can increase quality but will take longer to generate",
        default=40, minimum=1, maximum=80, step=1,
    ),
    gr.inputs.Slider(
        label="Number of images in parallel",
        default=2, minimum=1, maximum=4, step=1,
    ),
    gr.inputs.Slider(
        label="Weight - how closely the image should resemble the prompt",
        default=5, minimum=0, maximum=15, step=1,
    ),
    gr.inputs.Checkbox(
        label="CLIP Guided - improves coherence with complex prompts, makes it slower (with CLIP Guidance only one image is generated)",
    ),
]

ui_description = "<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/crowsonkb/v-diffusion-pytorch' target='_blank'>V-Diffusion</a> is diffusion text-to-image model created by <a href='https://twitter.com/RiversHaveWings' target='_blank'>Katherine Crowson</a> and <a href='https://twitter.com/jd_pressman'>JDP</a>, trained on the <a href='https://github.com/google-research-datasets/conceptual-12m'>CC12M dataset</a>. The UI to the model was assembled by <a style='color: rgb(99, 102, 241);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a>, keep up with the <a style='color: rgb(99, 102, 241);' href='https://multimodal.art/news' target='_blank'>latest multimodal ai art news here</a> and consider <a style='color: rgb(99, 102, 241);' href='https://www.patreon.com/multimodalart' target='_blank'>supporting us on Patreon</a></div>"

iface = gr.Interface(
    fn=run_all,
    inputs=ui_inputs,
    outputs=gallery,
    title="Generate images from text with V-Diffusion",
    description=ui_description,
)
# Start the web server with request queueing enabled.
iface.launch(enable_queue=True)