File size: 3,594 Bytes
36071c0
 
 
 
399b4c2
 
36071c0
 
50774a6
36071c0
69e62cf
36071c0
 
 
 
 
 
 
 
 
 
 
 
2e506d2
36071c0
50774a6
36071c0
18a451e
36071c0
 
 
 
 
 
 
 
 
 
 
404e08e
36071c0
 
 
 
 
 
18a451e
ca173e0
 
5ddf66b
 
18a451e
 
2e506d2
18a451e
 
 
36071c0
 
e3b915c
 
 
 
 
 
 
50774a6
b57dc0e
 
e3b915c
36071c0
 
e3b915c
 
 
 
b57dc0e
e3b915c
36071c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tensorflow as tf
import huggingface_hub as hf_hub
import gradio as gr

# Output mosaic layout: one batch of images is generated and tiled row by row
# into a grid that is num_rows tall and num_cols wide.
num_rows, num_cols = 2, 4
num_images = num_cols * num_rows

# Side length of the square noise tensors fed to the sampler, and the side
# length each image is resized to before being placed in the grid.
image_size = 64
plot_image_size = 128

# Pretrained Keras network fetched from the Hugging Face Hub when this module
# is executed. generate_images() calls it with
# [noisy_images, noise_rates, signal_rates] and unpacks the result as
# (pred_noises, pred_images).
model = hf_hub.from_pretrained_keras("keras-io/denoising-diffusion-models")

def diffusion_schedule(diffusion_times, min_signal_rate, max_signal_rate):
    """Convert diffusion times into matching noise and signal rates.

    The time is mapped linearly onto an angle that runs from
    ``acos(max_signal_rate)`` (at time 0) to ``acos(min_signal_rate)``
    (at time 1). The signal rate is the cosine of that angle and the noise
    rate is its sine, so their squares always sum to one.

    Args:
        diffusion_times: Tensor (or scalar) of diffusion times.
        min_signal_rate: Signal rate produced at diffusion time 1.
        max_signal_rate: Signal rate produced at diffusion time 0.

    Returns:
        A ``(noise_rates, signal_rates)`` tuple, in that order.
    """
    angle_at_start = tf.acos(max_signal_rate)
    angle_span = tf.acos(min_signal_rate) - angle_at_start

    # Linear interpolation between the two boundary angles.
    angles = angle_at_start + diffusion_times * angle_span

    return tf.sin(angles), tf.cos(angles)

def generate_images(diffusion_steps, stochasticity, min_signal_rate, max_signal_rate):
    """Sample a batch of images with reverse diffusion and tile them in a grid.

    Starting from Gaussian noise, the module-level ``model`` is applied
    ``diffusion_steps`` times; each step separates the current noisy images
    into predicted noise and predicted images and re-mixes them at the rates
    of the next (less noisy) diffusion time.

    Args:
        diffusion_steps: Number of reverse diffusion steps. Converted with
            ``int()`` (so ``10.0`` is accepted and fractions are truncated);
            must be at least 1 after conversion.
        stochasticity: Scale of the fresh noise injected at every step;
            0.0 injects none.
        min_signal_rate: Signal rate of the schedule at diffusion time 1.
        max_signal_rate: Signal rate of the schedule at diffusion time 0.

    Returns:
        A NumPy array of shape
        ``(num_rows * plot_image_size, num_cols * plot_image_size, 3)``
        with values clipped to ``[0.0, 1.0]``.

    Raises:
        ValueError: If ``diffusion_steps`` is smaller than 1.
    """
    # The step count may arrive as a float (e.g. 10.0 from a UI slider), which
    # range() rejects. A count below 1 would otherwise end in a division by
    # zero or an unbound `pred_images`, so fail early with a clear message.
    diffusion_steps = int(diffusion_steps)
    if diffusion_steps < 1:
        raise ValueError(
            f"diffusion_steps must be at least 1, got {diffusion_steps}"
        )

    step_size = 1.0 / diffusion_steps
    initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, 3))

    # reverse diffusion: walk the diffusion time from 1 down to 0
    noisy_images = initial_noise
    for step in range(diffusion_steps):
        diffusion_times = tf.ones((num_images, 1, 1, 1)) - step * step_size
        next_diffusion_times = diffusion_times - step_size

        noise_rates, signal_rates = diffusion_schedule(
            diffusion_times, min_signal_rate, max_signal_rate
        )
        next_noise_rates, next_signal_rates = diffusion_schedule(
            next_diffusion_times, min_signal_rate, max_signal_rate
        )

        # Fresh noise for this step and the rate at which it is mixed in;
        # the rate is zero when stochasticity is zero.
        sample_noises = tf.random.normal(shape=(num_images, image_size, image_size, 3))
        sample_noise_rates = (
            stochasticity
            * (1.0 - (signal_rates / next_signal_rates) ** 2) ** 0.5
            * (next_noise_rates / noise_rates)
        )

        pred_noises, pred_images = model([noisy_images, noise_rates, signal_rates])
        # Re-mix the predictions at the next time's rates. The predicted-noise
        # weight is reduced so that, together with the fresh noise, the total
        # noise rate equals next_noise_rates.
        noisy_images = (
            next_signal_rates * pred_images
            + (next_noise_rates**2 - sample_noise_rates**2) ** 0.5 * pred_noises
            + sample_noise_rates * sample_noises
        )

    # denormalize the last image prediction with per-channel mean / std-dev
    # (presumably the statistics of the training data -- confirm)
    data_mean = tf.constant([[[[0.4705, 0.3943, 0.3033]]]])
    data_std_dev = tf.constant([[[[0.2892, 0.2364, 0.2680]]]])
    generated_images = data_mean + pred_images * data_std_dev
    generated_images = tf.clip_by_value(generated_images, 0.0, 1.0)

    # make grid: upscale, then interleave the row/column axes so the batch
    # becomes a single (rows * height, cols * width, 3) image
    generated_images = tf.image.resize(
        generated_images, (plot_image_size, plot_image_size), method="nearest"
    )
    generated_images = tf.reshape(
        generated_images, (num_rows, num_cols, plot_image_size, plot_image_size, 3)
    )
    generated_images = tf.transpose(generated_images, (0, 2, 1, 3, 4))
    generated_images = tf.reshape(
        generated_images, (num_rows * plot_image_size, num_cols * plot_image_size, 3)
    )
    return generated_images.numpy()

# Input widgets, listed in the same order as the parameters of
# generate_images().
inputs = [
    gr.inputs.Slider(
        minimum=1, maximum=20, step=1, default=10, label="Diffusion steps"
    ),
    gr.inputs.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        default=0.0,
        label="Stochasticity (η in the paper)",
    ),
    gr.inputs.Slider(
        minimum=0.02, maximum=0.10, step=0.01, default=0.02, label="Minimal signal rate"
    ),
    gr.inputs.Slider(
        minimum=0.80, maximum=0.95, step=0.01, default=0.95, label="Maximal signal rate"
    ),
]
output = gr.outputs.Image(label="Generated images")

# Preset slider combinations offered as clickable examples.
examples = [
    [3, 0.0, 0.02, 0.95],
    [10, 0.0, 0.02, 0.95],
    [20, 1.0, 0.02, 0.95],
]

title = "Denoising Diffusion Implicit Models 🌹💨"
description = (
    "Generating images with a denoising diffusion implicit model, "
    "trained on the Oxford Flowers dataset."
)
article = (
    "<div style='text-align: center;'>"
    "Keras code example and demo by "
    "<a href='https://www.linkedin.com/in/andras-beres-789190210' target='_blank'>"
    "András Béres</a></div>"
)

# Build the web UI around generate_images() and start serving it.
gr.Interface(
    fn=generate_images,
    inputs=inputs,
    outputs=output,
    examples=examples,
    title=title,
    description=description,
    article=article,
).launch()