# diffusion / app.py
# (Hugging Face file-viewer residue, kept only as a comment so the module parses:
#  multimodalart's picture | Update app.py | 25ed0da | raw | history blame | No virus | 9.8 kB)
import gc
import math
import sys
from IPython import display
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import utils as tv_utils
from torchvision.transforms import functional as TF
import gradio as gr
from git.repo.base import Repo
from os.path import exists as path_exists
# Clone the source checkouts this app imports from, skipping any directory
# that already exists in the working directory.
for _repo_url, _checkout_dir in (
    ("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch"),
    ("https://github.com/openai/CLIP", "CLIP"),
):
    if path_exists(_checkout_dir):
        continue
    Repo.clone_from(_repo_url, _checkout_dir)

# Make modules inside the v-diffusion-pytorch checkout importable by their
# top-level name.
sys.path.append('v-diffusion-pytorch')
from huggingface_hub import hf_hub_download
from CLIP import clip
from diffusion import get_model, sampling, utils
class MakeCutouts(nn.Module):
    """Take random square crops of an image batch and pool each to a fixed size.

    Args:
        cut_size: Side length every crop is pooled to.
        cutn: Number of crops drawn per forward pass.
        cut_pow: Exponent applied to the uniform [0, 1) draw that picks each
            crop's side length. With values above 1 the draw is pushed toward
            0, i.e. toward the smallest allowed crop.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def forward(self, input):
        """Return ``cutn`` random crops of ``input`` stacked along dim 0.

        ``input`` is indexed with height at dim 2 and width at dim 3. Every
        crop keeps the full extent of dims 0 and 1, so the result has
        ``cutn * input.shape[0]`` entries along dim 0 and spatial size
        ``cut_size`` x ``cut_size``.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        span = largest - smallest

        crops = []
        for _ in range(self.cutn):
            # Random draws happen in the order: side length, x offset,
            # y offset. Keeping that order keeps seeded runs reproducible.
            side = int(torch.rand([]) ** self.cut_pow * span + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            crop = input[:, :, top:top + side, left:left + side]
            crops.append(F.adaptive_avg_pool2d(crop, self.cut_size))
        return torch.cat(crops)
def spherical_dist_loss(x, y):
    """Return an angular distance between ``x`` and ``y`` along the last dim.

    Both inputs are first scaled to unit length along the last dimension.
    For unit vectors the chord length ``c = ||x - y||`` and the angle
    ``theta`` between them satisfy ``theta = 2 * arcsin(c / 2)``; the value
    returned is ``2 * arcsin(c / 2) ** 2``, which equals ``theta ** 2 / 2``.
    The last dimension is reduced away.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1).div(2)
    half_angle = half_chord.arcsin()
    return half_angle.pow(2).mul(2)
# Download the `cc12m_1_cfg` checkpoint file from the Hugging Face Hub.
cc12m_model = hf_hub_download(
    repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg",
    filename="cc12m_1_cfg.pth",
)
# NOTE: a second, currently disabled variant ('cc12m_1', file "cc12m_1.pth" in
# the same Hub repo) used to be sketched here as commented-out code.

model = get_model('cc12m_1_cfg')()
# Only the last two entries of model.shape are kept; run_all uses them as the
# height and width of the noise tensor it starts sampling from.
_, side_y, side_x = model.shape
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
# Half precision, on the GPU, eval mode, parameter gradients disabled.
model = (
    model
    .half()
    .cuda()
    .eval()
    .requires_grad_(False)
)

# The diffusion model names the CLIP variant to load alongside it.
print(model.clip_model)
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
clip_model.eval().requires_grad_(False)

# Per-channel mean/std applied to images before they are passed to CLIP.
normalize = transforms.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)
# 16 cutouts per step, each pooled to CLIP's visual input resolution.
make_cutouts = MakeCutouts(
    cut_size=clip_model.visual.input_resolution,
    cutn=16,
    cut_pow=1.,
)
def run_all(prompt, steps, n_images, weight, clip_guided):
    """Generate a batch of images for a text prompt.

    Two sampling modes are available:

    * ``clip_guided`` falsy -- the model is evaluated on an unconditional
      (all-zero embedding) and a prompt-conditioned copy of the batch, and
      the two predictions are extrapolated by ``weight``.
    * ``clip_guided`` truthy -- the model is conditioned on the prompt
      embedding and every step is additionally shifted along the gradient of
      a CLIP image/text loss scaled by ``weight * 100``. Five times as many
      sampling steps are used in this mode.

    Args:
        prompt: Text prompt; it is embedded verbatim with CLIP.
        steps: Number of sampling steps (multiplied by 5 when CLIP guided).
        n_images: Number of images generated in one batch.
        weight: Guidance strength, interpreted per mode as described above.
        clip_guided: Selects the CLIP-guided sampling path.

    Returns:
        A list holding one ``utils.to_pil_image`` result per sample.
    """
    import random

    # These values come from UI sliders, but they are used as a tensor
    # dimension and as a linspace step count, which must be true integers.
    steps = int(steps)
    n_images = int(n_images)

    seed = random.randint(0, 2147483647)
    target_embed = clip_model.encode_text(clip.tokenize(prompt).to('cuda')).float()

    # NOTE: the CLIP-guided path used to parse the prompt as "text:weight"
    # and build a weighted embedding here. That result was immediately
    # overwritten by the plain prompt embedding, so it never influenced the
    # output -- but it did crash in float() for any prompt containing ':'
    # followed by non-numeric text, and it rebound the `weight` and `prompt`
    # parameters. The dead computation has been removed; the embedding that
    # is actually used for guidance is unchanged.
    if clip_guided:
        steps *= 5
    clip_guidance_scale = weight * 100
    # Must match the number of cutouts make_cutouts produces per call.
    n_cuts = make_cutouts.cutn

    def cfg_model_fn(x, t):
        """The CFG wrapper function."""
        n = x.shape[0]
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed_repeat = target_embed.repeat([n, 1])
        # First half of the doubled batch is unconditional (zero embedding),
        # second half is conditioned on the prompt.
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        return v_uncond + (v_cond - v_uncond) * weight

    def make_cond_model_fn(model, cond_fn):
        """Wrap ``model`` so its output is shifted by ``cond_fn``'s gradient."""
        def cond_model_fn(x, t, **extra_args):
            # Gradients w.r.t. x are required even if the caller disabled them.
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                v = model(x, t, **extra_args)
                alphas, sigmas = utils.t_to_alpha_sigma(t)
                pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
                cond_grad = cond_fn(x, t, pred, **extra_args).detach()
                v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
            return v
        return cond_model_fn

    def cond_fn(x, t, pred, clip_embed):
        """Return the negated gradient of the CLIP loss of ``pred`` w.r.t. ``x``."""
        # Upsample small predictions before taking cutouts.
        if min(pred.shape[2:4]) < 256:
            pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        # (pred + 1) / 2 rescales before the CLIP normalisation is applied.
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        image_embeds = clip_model.encode_image(clip_in).view([n_cuts, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * clip_guidance_scale
        return -torch.autograd.grad(loss, x)[0]

    # Release memory left over from a previous request before allocating.
    gc.collect()
    torch.cuda.empty_cache()
    torch.manual_seed(seed)

    x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
    t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
    if model.min_t == 0:
        step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    else:
        step_list = utils.get_ddpm_schedule(t)

    if clip_guided:
        clip_embed = target_embed.repeat([n_images, 1])
        model_fn = make_cond_model_fn(model, cond_fn)
        outs = sampling.plms_sample(model_fn, x, step_list, {'clip_embed': clip_embed})
    else:
        outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})

    return [utils.to_pil_image(out) for out in outs]
##################### START GRADIO HERE ############################
# Output widget: the generated images are shown one at a time in a carousel.
gallery = gr.outputs.Carousel(label="Individual images", components=["image"])

# Input widgets, in the positional order of run_all's parameters:
# prompt, steps, n_images, weight, clip_guided.
_ui_inputs = [
    gr.inputs.Textbox(
        label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",
        default="an eerie alien forest",
    ),
    gr.inputs.Slider(
        label="Steps - more steps can increase quality but will take longer to generate",
        minimum=1,
        maximum=80,
        default=40,
        step=1,
    ),
    gr.inputs.Slider(
        label="Number of images in parallel",
        minimum=1,
        maximum=4,
        default=2,
        step=1,
    ),
    gr.inputs.Slider(
        label="Weight - how closely the image should resemble the prompt",
        minimum=0,
        maximum=15,
        default=5,
        step=1,
    ),
    gr.inputs.Checkbox(
        label="CLIP Guided - improves coherence with complex prompts, makes it slower",
    ),
]

# HTML blurb rendered under the title.
_ui_description = "<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/crowsonkb/v-diffusion-pytorch' target='_blank'>V-Diffusion</a> is diffusion text-to-image model created by <a href='https://twitter.com/RiversHaveWings' target='_blank'>Katherine Crowson</a> and <a href='https://twitter.com/jd_pressman'>JDP</a>, trained on the <a href='https://github.com/google-research-datasets/conceptual-12m'>CC12M dataset</a>. The UI to the model was assembled by <a style='color: rgb(99, 102, 241);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a>, keep up with the <a style='color: rgb(99, 102, 241);' href='https://multimodal.art/news' target='_blank'>latest multimodal ai art news here</a> and consider <a style='color: rgb(99, 102, 241);' href='https://www.patreon.com/multimodalart' target='_blank'>supporting us on Patreon</a></div>"

# NOTE: a commented-out `article=` footer (a bias acknowledgment and an
# image-ownership FAQ) used to sit in this call; its text referred to the
# Latent Diffusion paper and the LAION-400M dataset.
iface = gr.Interface(
    fn=run_all,
    inputs=_ui_inputs,
    outputs=gallery,
    title="Generate images from text with V-Diffusion",
    description=_ui_description,
)

# Start serving the UI with request queueing turned on.
iface.launch(enable_queue=True)