import gradio as gr
import numpy as np
import torch
from diffusers import DDPMPipeline, DDIMScheduler
import open_clip
import torchvision
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
# Initialize device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP model
clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model.to(device)
# Transform to preprocess images
tfms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])
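# Note: the values above are the standard OpenAI CLIP preprocessing statistics,
# and the ViT-B/32 image encoder expects 224x224 inputs, hence the fixed Resize.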
# CLIP Loss function
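# The loss embeds the (predicted) image with CLIP and measures how far it is
# from the text embedding: minimizing 1 - cosine similarity pulls the sample
# toward the prompt.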
def clip_loss(image, text_features):
    # Ensure image is in the correct format (B, C, H, W)
    if image.dim() == 3:
        image = image.unsqueeze(0)
    # Apply transforms
    image = tfms(image)
    # Encode image
    image_features = clip_model.encode_image(image)
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    loss = (1 - torch.cosine_similarity(image_features, text_features)).mean()
    return loss
# Load Diffusion model
model_repo_id = "muneebable/ddpm-celebahq-finetuned-anime-art"
image_pipe = DDPMPipeline.from_pretrained(model_repo_id)
image_pipe.to(device)
# Load scheduler
scheduler = DDIMScheduler.from_pretrained(model_repo_id)
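# DDIM sampling lets us denoise with far fewer steps than the scheduler's full
# training schedule; the step count is exposed as a slider below.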
def generate_image(prompt, guidance_scale, num_steps):
    # Gradio sliders may deliver floats, so cast defensively
    scheduler.set_timesteps(num_inference_steps=int(num_steps))
    # We embed a prompt with CLIP as our target
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(text)
    # Start from pure noise at the model's native 256x256 resolution
    x = torch.randn(1, 3, 256, 256).to(device)
    n_cuts = 4
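    # Sampling loop: at every step we (1) predict the noise, (2) form the
    # scheduler's estimate of the clean image x0, (3) backprop the CLIP loss on
    # x0 to get a gradient w.r.t. x, (4) nudge x in the direction that lowers
    # the loss, and then (5) take the normal scheduler step.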
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        # predict the noise residual
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        cond_grad = 0
        for cut in range(n_cuts):
            # Set requires grad on x
            x = x.detach().requires_grad_()
            # Get the predicted x0:
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample
            # Calculate loss
            loss = clip_loss(x0, text_features) * guidance_scale
            # Get gradient (scale by n_cuts since we want the average)
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts
        # Modify x based on this gradient
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()
        # Now step with scheduler
        x = scheduler.step(noise_pred, t, x).prev_sample
    # Convert the tensor to a PIL Image
    x = x.squeeze(0).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    x = (x * 255).byte().numpy()
    return Image.fromarray(x)
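# A minimal sketch of calling the generator directly (values are illustrative):
#   img = generate_image("Red Rose (still life), red flower painting", guidance_scale=8, num_steps=50)
#   img.save("rose.png")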
# Gradio interface
def gradio_interface(prompt, guidance_scale, num_steps):
    return generate_image(prompt, guidance_scale, num_steps)
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Prompt", value="Red Rose (still life), red flower painting"),
        gr.Slider(minimum=1, maximum=20, step=1, label="Guidance Scale", value=8),
        gr.Slider(minimum=10, maximum=100, step=10, label="Number of Steps", value=50)
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="CLIP-Guided Diffusion Image Generation",
    description="Generate images using CLIP-guided diffusion. Enter a prompt, adjust the guidance scale, and set the number of steps.",
    examples=[
        ["A serene landscape with mountains and a lake", 10, 20],
        ["A futuristic cityscape at night", 15, 50],
        ["Red Rose (still life), red flower painting", 5, 50]
    ]
)
iface.launch()