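# Gradio app: CLIP-guided diffusion using an OpenAI 256x256 diffusion model
# fine-tuned on comic-style faces by Alex Spirin (checkpoint from the DiscoDiffusion-Warp releases).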
import os
import sys
import gradio as gr
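# Clone and install CLIP and guided-diffusion, then download the fine-tuned checkpoint.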
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/crowsonkb/guided-diffusion')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
os.system('pip install lpips')
os.system("curl -OL 'https://github.com/Sxela/DiscoDiffusion-Warp/releases/download/v0.1.1/256x256_openai_comics_faces_v2.by_alex_spirin_114k.pt'")
import io
import math
import lpips
from PIL import Image
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import numpy as np
import imageio
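# Download a sample image at startup (the file is not referenced later in this script).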
torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'face.jpeg')
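# Return a binary file-like object for either a URL or a local path.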
def fetch(url_or_path):
if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
r = requests.get(url_or_path)
r.raise_for_status()
fd = io.BytesIO()
fd.write(r.content)
fd.seek(0)
return fd
return open(url_or_path, 'rb')
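# Split a "prompt:weight" string into (text, weight); URLs keep the "://" part intact.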
def parse_prompt(prompt):
if prompt.startswith('http://') or prompt.startswith('https://'):
vals = prompt.rsplit(':', 2)
vals = [vals[0] + ':' + vals[1], *vals[2:]]
else:
vals = prompt.rsplit(':', 1)
vals = vals + ['', '1'][len(vals):]
return vals[0], float(vals[1])
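# Take cutn random square crops of the input and resize them to CLIP's input resolution.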
class MakeCutouts(nn.Module):
def __init__(self, cut_size, cutn, cut_pow=1.):
super().__init__()
self.cut_size = cut_size
self.cutn = cutn
self.cut_pow = cut_pow
def forward(self, input):
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
min_size = min(sideX, sideY, self.cut_size)
cutouts = []
for _ in range(self.cutn):
size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
return torch.cat(cutouts)
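# Squared spherical (geodesic) distance between L2-normalized embeddings.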
def spherical_dist_loss(x, y):
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
def tv_loss(input):
"""L2 total variation loss, as in Mahendran et al."""
input = F.pad(input, (0, 1, 0, 1), 'replicate')
x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
return (x_diff**2 + y_diff**2).mean([1, 2, 3])
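# Penalize pixel values that fall outside the [-1, 1] range.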
def range_loss(input):
return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn, im_prompt_weight):
# Model settings
    # Convert the "style strength" percentage into a number of timesteps to skip
    # (0% stays close to the init image, 100% runs the full schedule), then clamp
    # to the valid range for the chosen respacing.
    skip_timesteps = int((timestep_respacing - 1) * (1 - skip_timesteps / 100))
    skip_timesteps = max(0, min(skip_timesteps, int(timestep_respacing) - 1))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_config = model_and_diffusion_defaults()
model_config.update({
'attention_resolutions': '16',
'class_cond': False,
'diffusion_steps': 1000,
'rescale_timesteps': True,
        'timestep_respacing': str(int(timestep_respacing)),  # respacing string must be an integer count
'image_size': 256,
'learn_sigma': True,
'noise_schedule': 'linear',
'num_channels': 128,
'num_heads': 1,
'num_res_blocks': 2,
'use_checkpoint': True,
'use_fp16': False if device.type == 'cpu' else True,
'use_scale_shift_norm': False,
})
# Load models
print('Using fp16: ',model_config['use_fp16'])
print('Using device:', device)
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('256x256_openai_comics_faces_v2.by_alex_spirin_114k.pt', map_location='cpu'))
model.requires_grad_(False).eval().to(device).float()
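    # Re-enable gradients only on the attention (qkv), normalization and projection weights.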
for name, param in model.named_parameters():
if 'qkv' in name or 'norm' in name or 'proj' in name:
param.requires_grad_()
if model_config['use_fp16']:
model.convert_to_fp16()
else: model.convert_to_fp32()
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device).float()
clip_size = clip_model.visual.input_resolution
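    # CLIP's image normalization constants.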
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
all_frames = []
prompts = [text]
batch_size = 1
    # clip_guidance_scale controls how much the image should look like the prompt.
    # tv_scale controls the smoothness of the final output.
    # range_scale controls how far out of range RGB values are allowed to be.
    # skip_timesteps: higher values make the output look more like the init image.
    # init_scale enhances the effect of the init image; a good value is 1000.
    cutn = int(cutn)  # cutout count must be an integer for view()/range() below
    n_batches = 1
if seed is not None:
torch.manual_seed(seed)
make_cutouts = MakeCutouts(clip_size, cutn)
side_x = side_y = model_config['image_size']
target_embeds, weights = [], []
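    # Encode the text prompt(s) with CLIP.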
for prompt in prompts:
txt, weight = parse_prompt(prompt)
target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
weights.append(weight)
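    # Optionally add an image prompt: embed cutouts of the uploaded image and weight them.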
if image_prompts is not None:
img = Image.fromarray(image_prompts).convert('RGB')
img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
embed = clip_model.encode_image(normalize(batch)).float()
target_embeds.append(embed)
weights.extend([im_prompt_weight / cutn] * cutn)
target_embeds = torch.cat(target_embeds)
weights = torch.tensor(weights, device=device)
if weights.sum().abs() < 1e-3:
raise RuntimeError('The weights must not sum to 0.')
weights /= weights.sum().abs()
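    # Optional init image: used both to start sampling and for the LPIPS init loss.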
init = None
if init_image is not None:
lpips_model = lpips.LPIPS(net='vgg').to(device)
init = Image.fromarray(init_image).convert('RGB')
init = init.resize((side_x, side_y), Image.LANCZOS)
init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
else: skip_timesteps = 0
cur_t = None
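    # Guidance function passed to the sampler: returns the negative gradient of the
    # CLIP / TV / range / LPIPS losses with respect to the noisy image x.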
def cond_fn(x, t, y=None):
with torch.enable_grad():
x = x.detach().requires_grad_()
n = x.shape[0]
my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
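            # Blend the predicted clean image with the current sample according to the remaining noise level.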
x_in = out['pred_xstart'] * fac + x * (1 - fac)
clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
image_embeds = clip_model.encode_image(clip_in).float()
dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
dists = dists.view([cutn, n, -1])
losses = dists.mul(weights).sum(2).mean(0)
tv_losses = tv_loss(x_in)
range_losses = range_loss(out['pred_xstart'])
loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
if init is not None and init_scale:
init_losses = lpips_model(x_in, init)
loss = loss + init_losses.sum() * init_scale
return -torch.autograd.grad(loss, x)[0]
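    # Pick DDIM or ancestral sampling depending on the timestep respacing string.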
if model_config['timestep_respacing'].startswith('ddim'):
sample_fn = diffusion.ddim_sample_loop_progressive
else:
sample_fn = diffusion.p_sample_loop_progressive
for i in range(n_batches):
cur_t = diffusion.num_timesteps - skip_timesteps - 1
samples = sample_fn(
model,
(batch_size, 3, side_y, side_x),
clip_denoised=False,
model_kwargs={},
cond_fn=cond_fn,
progress=True,
skip_timesteps=skip_timesteps,
init_image=init,
randomize_class=True,
)
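        # Collect the predicted x_0 at every step as a video frame.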
for j, sample in enumerate(samples):
cur_t -= 1
if j % 1 == 0 or cur_t == -1:
print()
for k, image in enumerate(sample['pred_xstart']):
img = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
all_frames.append(img)
tqdm.write(f'Batch {i}, step {j}, output {k}:')
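    # Write the collected frames to video.mp4 at 5 fps.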
writer = imageio.get_writer('video.mp4', fps=5)
for im in all_frames:
writer.append_data(np.array(im))
writer.close()
return img, 'video.mp4'
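# Build the Gradio Blocks UI.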
demo = gr.Blocks()
with demo:
gr.Markdown(
"""
# CLIP Guided OpenAI Diffusion Faces Model
### by [Alex Spirin](https://linktr.ee/devdef)
Gradio Blocks demo for CLIP Guided Diffusion. To use it, simply enter a text prompt (and optionally an init image or an image prompt) and press Run.
Based on the original [Space](https://huggingface.co/spaces/EleutherAI/clip-guided-diffusion) by akhaliq.
![visitors](https://visitor-badge.glitch.me/badge?page_id=sxela_dd_custom_model_hf_space)
""")
with gr.Row():
text = gr.Textbox(placeholder="Enter a description of a face", label='Text prompt', value="A beautiful girl by Greg Rutkowski")
with gr.Tabs():
with gr.TabItem("Settings"):
with gr.Row():
# with gr.Group():
with gr.Column():
clip_guidance_scale = gr.Slider(minimum=0, maximum=3000, step=1, value=600, label="Prompt strength")
tv_scale = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Smoothness")
range_scale = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Compress color range")
# with gr.Group():
with gr.Column():
timestep_respacing = gr.Slider(minimum=25, maximum=100, step=1, value=25, label="Timestep respacing")
cutn = gr.Slider(minimum=4, maximum=32, step=1, value=16, label="cutn")
seed = gr.Number(value=0, label="Seed")
with gr.TabItem("Input images"):
with gr.Row():
# with gr.Group():
with gr.Column():
init_image = gr.Image(source="upload", label='initial image (optional)')
init_scale = gr.Slider(minimum=0, maximum=1000, step=10, value=0, label="Look like the image above")
skip_timesteps = gr.Slider(minimum=0, maximum=100, step=1, value=30, label="Style strength, % (0 = initial image)")
# with gr.Group():
with gr.Column():
image_prompts = gr.Image(source="upload", label='image prompt (optional)')
im_prompt_weight = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Look like the image above")
with gr.Group():
with gr.Row():
gr.Markdown(
"""
### Press Run to Run :D
----
""")
with gr.Row():
run_button = gr.Button("Run!")
with gr.Row():
gr.Markdown(
"""
### Results
---
""")
with gr.Row():
output_image = gr.Image(label='Output image', type='numpy')
output_video = gr.Video(label='Output video')
outputs=[output_image,output_video]
run_button.click(inference, inputs=[text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn, im_prompt_weight], outputs=outputs)
demo.launch(enable_queue=True)