import gc
import math
import sys
from IPython import display
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import utils as tv_utils
from torchvision.transforms import functional as TF
import gradio as gr
from git.repo.base import Repo
from os.path import exists as path_exists
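
# Clone the v-diffusion-pytorch and CLIP repos on first run so their modules can be imported.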
if not path_exists("v-diffusion-pytorch"):
    Repo.clone_from("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch")
if not path_exists("CLIP"):
    Repo.clone_from("https://github.com/openai/CLIP", "CLIP")
sys.path.append('v-diffusion-pytorch')
from huggingface_hub import hf_hub_download
from CLIP import clip
from diffusion import get_model, sampling, utils
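
# MakeCutouts crops `cutn` random square patches from an image batch and resizes
# them to CLIP's input resolution; cut_pow biases the distribution of crop sizes.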
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)
            cutouts.append(cutout)
        return torch.cat(cutouts)
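
# Squared great-circle distance between L2-normalized embeddings; used as the CLIP guidance loss.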
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
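
# Download the cc12m_1_cfg checkpoint from the Hugging Face Hub, load it in half
# precision on the GPU, and freeze it for inference.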
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
#cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
model = get_model('cc12m_1_cfg')()
_, side_y, side_x = model.shape
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
model = model.half().cuda().eval().requires_grad_(False)
#model_small = get_model('cc12m_1')()
#model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
#model_small = model_small.half().cuda().eval().requires_grad_(False)
print(model.clip_model)
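
# Load the CLIP model that matches the diffusion model's conditioning encoder,
# plus the standard CLIP preprocessing normalization.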
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
clip_model.eval().requires_grad_(False)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
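
# Generate `n_images` images for `prompt`. By default sampling uses classifier-free
# guidance only; with `clip_guided` enabled, CLIP gradients additionally steer each step.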
def run_all(prompt, steps, n_images, weight, clip_guided):
    import random
    seed = int(random.randint(0, 2147483647))
    target_embed = clip_model.encode_text(clip.tokenize(prompt).to('cuda')).float()
    if clip_guided:
        steps = steps * 5
        clip_guidance_scale = weight * 100
        prompts = [prompt]
        target_embeds, weights = [], []
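
        # parse_prompt splits an optional ':weight' suffix off a prompt, keeping the
        # scheme colon intact for http(s) URLs.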
        def parse_prompt(prompt):
            if prompt.startswith('http://') or prompt.startswith('https://'):
                vals = prompt.rsplit(':', 2)
                vals = [vals[0] + ':' + vals[1], *vals[2:]]
            else:
                vals = prompt.rsplit(':', 1)
            vals = vals + ['', '1'][len(vals):]
            return vals[0], float(vals[1])

        for prompt in prompts:
            txt, weight = parse_prompt(prompt)
            target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to('cuda')).float())
            weights.append(weight)
        target_embeds = torch.cat(target_embeds)
        weights = torch.tensor(weights, device='cuda')
        if weights.sum().abs() < 1e-3:
            raise RuntimeError('The weights must not sum to 0.')
        weights /= weights.sum().abs()
        clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
        clip_embed = target_embed.repeat([n_images, 1])
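
    # Classifier-free guidance: run an unconditioned and a conditioned copy of the
    # batch through the model and extrapolate between the two predictions by `weight`.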
    def cfg_model_fn(x, t):
        """The CFG wrapper function."""
        n = x.shape[0]
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed_repeat = target_embed.repeat([n, 1])
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        v = v_uncond + (v_cond - v_uncond) * weight
        return v
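
    # Wrap the model so the gradient returned by cond_fn is folded back into the
    # predicted velocity at each sampling step.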
    def make_cond_model_fn(model, cond_fn):
        def cond_model_fn(x, t, **extra_args):
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                v = model(x, t, **extra_args)
                alphas, sigmas = utils.t_to_alpha_sigma(t)
                pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
                cond_grad = cond_fn(x, t, pred, **extra_args).detach()
                v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
            return v
        return cond_model_fn
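
    # cond_fn scores the denoised estimate with CLIP against the prompt embedding
    # (via random cutouts) and returns the gradient that pulls the image toward the
    # prompt; clip_guidance_scale controls its strength.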
    def cond_fn(x, t, pred, clip_embed):
        if min(pred.shape[2:4]) < 256:
            pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        image_embeds = clip_model.encode_image(clip_in).view([16, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * clip_guidance_scale
        grad = -torch.autograd.grad(loss, x)[0]
        return grad
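
    # Free cached GPU memory, seed the RNG, then build the starting noise and the
    # timestep schedule before PLMS sampling.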
    gc.collect()
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
    t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
    #step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    if model.min_t == 0:
        step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    else:
        step_list = utils.get_ddpm_schedule(t)
    if not clip_guided:
        outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
    else:
        extra_args = {'clip_embed': clip_embed}
        cond_fn_ = cond_fn
        model_fn = make_cond_model_fn(model, cond_fn_)
        outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
    images_out = []
    for out in outs:
        images_out.append(utils.to_pil_image(out))
    return images_out
##################### START GRADIO HERE ############################
gallery = gr.outputs.Carousel(label="Individual images", components=["image"])

iface = gr.Interface(
    fn=run_all,
    inputs=[
        gr.inputs.Textbox(label="Prompt - try adding modifiers to your prompt such as 'oil on canvas', 'a painting', 'a book cover'", default="an eerie alien forest"),
        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate", default=40, maximum=80, minimum=1, step=1),
        gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
        gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1),
        gr.inputs.Checkbox(label="CLIP Guided - improves coherence with complex prompts, but makes generation slower"),
    ],
    outputs=gallery,
    title="Generate images from text with V-Diffusion",
    description="<div>By typing a prompt and pressing submit you can generate images based on that prompt. <a href='https://github.com/crowsonkb/v-diffusion-pytorch' target='_blank'>V-Diffusion</a> is a diffusion text-to-image model created by <a href='https://twitter.com/RiversHaveWings' target='_blank'>Katherine Crowson</a> and <a href='https://twitter.com/jd_pressman'>JDP</a>, trained on the <a href='https://github.com/google-research-datasets/conceptual-12m'>CC12M dataset</a>. The UI to the model was assembled by <a style='color: rgb(99, 102, 241);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a>, keep up with the <a style='color: rgb(99, 102, 241);' href='https://multimodal.art/news' target='_blank'>latest multimodal AI art news here</a> and consider <a style='color: rgb(99, 102, 241);' href='https://www.patreon.com/multimodalart' target='_blank'>supporting us on Patreon</a></div>",
    #article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into image is, be aware that this model may output content that reinforces or exacerbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>:<i> \"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The model was trained on an unfiltered version of the LAION-400M dataset, which scraped non-curated image-text pairs from the internet (the exception being the removal of illegal content) and is meant to be used for research purposes, such as this one. <a href='https://laion.ai/laion-400-open-dataset/' target='_blank'>You can read more on LAION's website</a></div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definitely not me! Probably you do. I say probably because the copyright discussion about AI-generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or is in the public domain.</div>"
)
iface.launch(enable_queue=True)