import gradio as gr
import numpy as np
import torch
from diffusers import DDPMPipeline, DDIMScheduler
import open_clip
import torchvision
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F

# Initialize device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model.to(device)
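clip_model.eval()  # CLIP is used for inference only; not strictly required, but makes that explicit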

# Transforms so CLIP sees 224x224 inputs normalized with its expected statistics
tfms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

# CLIP Loss function
def clip_loss(image, text_features):
    # Ensure image is in the correct format (B, C, H, W)
    if image.dim() == 3:
        image = image.unsqueeze(0)
    
    # Apply transforms
    image = tfms(image)
    
    # Encode image
    image_features = clip_model.encode_image(image)
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
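    # Cosine-distance loss: approaches 0 as the image embedding aligns with the text embedding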
    loss = (1 - torch.cosine_similarity(image_features, text_features)).mean()
    return loss

# Load the diffusion model (a DDPM fine-tuned on anime art)
model_repo_id = "muneebable/ddpm-celebahq-finetuned-anime-art"
image_pipe = DDPMPipeline.from_pretrained(model_repo_id)
image_pipe.to(device)

# Load a DDIM scheduler so sampling can use fewer steps than the original DDPM schedule
scheduler = DDIMScheduler.from_pretrained(model_repo_id)

def generate_image(prompt, guidance_scale, num_steps):
    scheduler.set_timesteps(num_inference_steps=int(num_steps))  # slider values may arrive as floats
    
    # We embed a prompt with CLIP as our target
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(text)
    
    x = torch.randn(1, 3, 256, 256).to(device)
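    # Number of gradient estimates ("cuts") averaged at each denoising step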
    n_cuts = 4
    
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
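        # Scale the latent for this timestep (an identity operation for DDIM, kept for API consistency)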
        model_input = scheduler.scale_model_input(x, t)
        # predict the noise residual
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
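        # Accumulate the guidance gradient on x, averaged over n_cuts evaluations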
        cond_grad = 0
        for cut in range(n_cuts):
            # Set requires grad on x
            x = x.detach().requires_grad_()
            # Get the predicted x0:
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample
            # Calculate loss
            loss = clip_loss(x0, text_features) * guidance_scale
            # Get gradient (scale by n_cuts since we want the average)
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts
        
        # Modify x based on this gradient
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()
        # Now step with scheduler
        x = scheduler.step(noise_pred, t, x).prev_sample
    
    # Convert the tensor to a PIL Image
    x = x.squeeze(0).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    x = (x * 255).byte().numpy()
    return Image.fromarray(x)

# Gradio interface
def gradio_interface(prompt, guidance_scale, num_steps):
    return generate_image(prompt, guidance_scale, num_steps)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Prompt", value="Red Rose (still life), red flower painting"),
        gr.Slider(minimum=1, maximum=20, step=1, label="Guidance Scale", value=8),
        gr.Slider(minimum=10, maximum=100, step=10, label="Number of Steps", value=50)
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="CLIP-Guided Diffusion Image Generation",
    description="Generate images using CLIP-guided diffusion. Enter a prompt, adjust the guidance scale, and set the number of steps.",
    examples=[
        ["A serene landscape with mountains and a lake", 10, 2],
        ["A futuristic cityscape at night", 15, 5],
        ["Red Rose (still life), red flower painting", 5, 5]
    ]
)

iface.launch()