Spaces:
Runtime error
Runtime error
Esmail-AGumaan
commited on
Commit
•
64e1ee8
1
Parent(s):
33e1004
Upload 13 files
Browse files- attention.py +77 -0
- clip.py +64 -0
- ddpm.py +112 -0
- decoder.py +100 -0
- demo.ipynb +0 -0
- diffusion.py +213 -0
- encoder.py +56 -0
- gaussing_diffusion.py +814 -0
- model_converter.py +0 -0
- model_loader.py +28 -0
- pipeline.py +141 -0
- sd_gradio.py +69 -0
- sd_inference.py +72 -0
attention.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
class SelfAttention(nn.Module):
|
7 |
+
def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
|
8 |
+
super().__init__()
|
9 |
+
self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
|
10 |
+
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
|
11 |
+
self.n_heads = n_heads
|
12 |
+
self.d_head = d_embed // n_heads
|
13 |
+
|
14 |
+
def forward(self, x, causal_mask=False):
|
15 |
+
input_shape = x.shape
|
16 |
+
batch_size, sequence_length, d_embed = input_shape
|
17 |
+
interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
|
18 |
+
|
19 |
+
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
20 |
+
q = q.view(interim_shape).transpose(1, 2)
|
21 |
+
k = k.view(interim_shape).transpose(1, 2)
|
22 |
+
v = v.view(interim_shape).transpose(1, 2)
|
23 |
+
|
24 |
+
weight = q @ k.transpose(-1, -2)
|
25 |
+
|
26 |
+
if causal_mask:
|
27 |
+
mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
|
28 |
+
weight.masked_fill_(mask, -torch.inf)
|
29 |
+
|
30 |
+
weight /= math.sqrt(self.d_head)
|
31 |
+
|
32 |
+
weight = F.softmax(weight, dim=-1)
|
33 |
+
output = weight @ v
|
34 |
+
output = output.transpose(1, 2)
|
35 |
+
|
36 |
+
output = output.reshape(input_shape)
|
37 |
+
output = self.out_proj(output)
|
38 |
+
return output
|
39 |
+
|
40 |
+
class CrossAttention(nn.Module):
|
41 |
+
def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
|
42 |
+
super().__init__()
|
43 |
+
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
|
44 |
+
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
|
45 |
+
self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
|
46 |
+
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
|
47 |
+
self.n_heads = n_heads
|
48 |
+
self.d_head = d_embed // n_heads
|
49 |
+
|
50 |
+
def forward(self, x, y):
|
51 |
+
input_shape = x.shape
|
52 |
+
batch_size, sequence_length, d_embed = input_shape
|
53 |
+
interim_shape = (batch_size, -1, self.n_heads, self.d_head)
|
54 |
+
|
55 |
+
q = self.q_proj(x)
|
56 |
+
k = self.k_proj(y)
|
57 |
+
v = self.v_proj(y)
|
58 |
+
|
59 |
+
q = q.view(interim_shape).transpose(1, 2)
|
60 |
+
k = k.view(interim_shape).transpose(1, 2)
|
61 |
+
v = v.view(interim_shape).transpose(1, 2)
|
62 |
+
|
63 |
+
weight = q @ k.transpose(-1, -2)
|
64 |
+
|
65 |
+
weight /= math.sqrt(self.d_head)
|
66 |
+
|
67 |
+
weight = F.softmax(weight, dim=-1)
|
68 |
+
|
69 |
+
output = weight @ v
|
70 |
+
|
71 |
+
output = output.transpose(1, 2).contiguous()
|
72 |
+
|
73 |
+
output = output.view(input_shape)
|
74 |
+
|
75 |
+
output = self.out_proj(output)
|
76 |
+
|
77 |
+
return output
|
clip.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from nanograd.models.stable_diffusion.attention import SelfAttention
|
5 |
+
|
6 |
+
class CLIPEmbedding(nn.Module):
|
7 |
+
def __init__(self, n_vocab: int, n_embd: int, n_token: int):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.token_embedding = nn.Embedding(n_vocab, n_embd)
|
11 |
+
self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
|
12 |
+
|
13 |
+
def forward(self, tokens):
|
14 |
+
x = self.token_embedding(tokens)
|
15 |
+
x += self.position_embedding
|
16 |
+
|
17 |
+
return x
|
18 |
+
|
19 |
+
class CLIPLayer(nn.Module):
|
20 |
+
def __init__(self, n_head: int, n_embd: int):
|
21 |
+
super().__init__()
|
22 |
+
self.layernorm_1 = nn.LayerNorm(n_embd)
|
23 |
+
self.attention = SelfAttention(n_head, n_embd)
|
24 |
+
self.layernorm_2 = nn.LayerNorm(n_embd)
|
25 |
+
self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
|
26 |
+
self.linear_2 = nn.Linear(4 * n_embd, n_embd)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
residue = x
|
30 |
+
x = self.layernorm_1(x)
|
31 |
+
x = self.attention(x, causal_mask=True)
|
32 |
+
x += residue
|
33 |
+
|
34 |
+
residue = x
|
35 |
+
x = self.layernorm_2(x)
|
36 |
+
x = self.linear_1(x)
|
37 |
+
|
38 |
+
x = x * torch.sigmoid(1.702 * x)
|
39 |
+
x = self.linear_2(x)
|
40 |
+
x += residue
|
41 |
+
|
42 |
+
return x
|
43 |
+
|
44 |
+
class CLIP(nn.Module):
|
45 |
+
def __init__(self):
|
46 |
+
super().__init__()
|
47 |
+
self.embedding = CLIPEmbedding(49408, 768, 77)
|
48 |
+
|
49 |
+
self.layers = nn.ModuleList([
|
50 |
+
CLIPLayer(12, 768) for i in range(12)
|
51 |
+
])
|
52 |
+
|
53 |
+
self.layernorm = nn.LayerNorm(768)
|
54 |
+
|
55 |
+
def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
|
56 |
+
tokens = tokens.type(torch.long)
|
57 |
+
|
58 |
+
state = self.embedding(tokens)
|
59 |
+
|
60 |
+
for layer in self.layers:
|
61 |
+
state = layer(state)
|
62 |
+
output = self.layernorm(state)
|
63 |
+
|
64 |
+
return output
|
ddpm.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class DDPMSampler: # Denoising Diffusion Probabilistic Models Sampler
|
5 |
+
|
6 |
+
def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
|
7 |
+
# Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
|
8 |
+
# For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
|
9 |
+
self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
|
10 |
+
self.alphas = 1.0 - self.betas
|
11 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
12 |
+
self.one = torch.tensor(1.0)
|
13 |
+
|
14 |
+
self.generator = generator
|
15 |
+
|
16 |
+
self.num_train_timesteps = num_training_steps
|
17 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
|
18 |
+
|
19 |
+
def set_inference_timesteps(self, num_inference_steps=50):
|
20 |
+
self.num_inference_steps = num_inference_steps
|
21 |
+
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
22 |
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
23 |
+
self.timesteps = torch.from_numpy(timesteps)
|
24 |
+
|
25 |
+
def _get_previous_timestep(self, timestep: int) -> int:
|
26 |
+
prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
|
27 |
+
return prev_t
|
28 |
+
|
29 |
+
def _get_variance(self, timestep: int) -> torch.Tensor:
|
30 |
+
prev_t = self._get_previous_timestep(timestep)
|
31 |
+
|
32 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
33 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
|
34 |
+
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
|
35 |
+
|
36 |
+
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
|
37 |
+
|
38 |
+
variance = torch.clamp(variance, min=1e-20)
|
39 |
+
|
40 |
+
return variance
|
41 |
+
|
42 |
+
def set_strength(self, strength=1):
|
43 |
+
"""
|
44 |
+
Set how much noise to add to the input image.
|
45 |
+
More noise (strength ~ 1) means that the output will be further from the input image.
|
46 |
+
Less noise (strength ~ 0) means that the output will be closer to the input image.
|
47 |
+
"""
|
48 |
+
# start_step is the number of noise levels to skip
|
49 |
+
start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
|
50 |
+
self.timesteps = self.timesteps[start_step:]
|
51 |
+
self.start_step = start_step
|
52 |
+
|
53 |
+
def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
|
54 |
+
t = timestep
|
55 |
+
prev_t = self._get_previous_timestep(t)
|
56 |
+
|
57 |
+
alpha_prod_t = self.alphas_cumprod[t]
|
58 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
|
59 |
+
beta_prod_t = 1 - alpha_prod_t
|
60 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
61 |
+
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
|
62 |
+
current_beta_t = 1 - current_alpha_t
|
63 |
+
|
64 |
+
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
65 |
+
pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
66 |
+
|
67 |
+
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
68 |
+
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
|
69 |
+
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
|
70 |
+
|
71 |
+
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
72 |
+
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
|
73 |
+
|
74 |
+
variance = 0
|
75 |
+
if t > 0:
|
76 |
+
device = model_output.device
|
77 |
+
noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
|
78 |
+
# Compute the variance as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
79 |
+
variance = (self._get_variance(t) ** 0.5) * noise
|
80 |
+
|
81 |
+
pred_prev_sample = pred_prev_sample + variance
|
82 |
+
|
83 |
+
return pred_prev_sample
|
84 |
+
|
85 |
+
def add_noise(
|
86 |
+
self,
|
87 |
+
original_samples: torch.FloatTensor,
|
88 |
+
timesteps: torch.IntTensor,
|
89 |
+
) -> torch.FloatTensor:
|
90 |
+
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
91 |
+
timesteps = timesteps.to(original_samples.device)
|
92 |
+
|
93 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
94 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
95 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
96 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
97 |
+
|
98 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
99 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
100 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
101 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
102 |
+
|
103 |
+
# Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
|
104 |
+
# Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
|
105 |
+
# here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
|
106 |
+
noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
|
107 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
108 |
+
return noisy_samples
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
decoder.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from nanograd.models.stable_diffusion.attention import SelfAttention
|
5 |
+
|
6 |
+
class VAE_AttentionBlock(nn.Module):
|
7 |
+
def __init__(self, channels):
|
8 |
+
super().__init__()
|
9 |
+
self.groupnorm = nn.GroupNorm(32, channels)
|
10 |
+
self.attention = SelfAttention(1, channels)
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
residue = x
|
14 |
+
x = self.groupnorm(x)
|
15 |
+
n, c, h, w = x.shape
|
16 |
+
x = x.view((n, c, h * w))
|
17 |
+
x = x.transpose(-1, -2)
|
18 |
+
x = self.attention(x)
|
19 |
+
x = x.transpose(-1, -2)
|
20 |
+
x = x.view((n, c, h, w))
|
21 |
+
x += residue
|
22 |
+
|
23 |
+
return x
|
24 |
+
|
25 |
+
class VAE_ResidualBlock(nn.Module):
|
26 |
+
def __init__(self, in_channels, out_channels):
|
27 |
+
super().__init__()
|
28 |
+
self.groupnorm_1 = nn.GroupNorm(32, in_channels)
|
29 |
+
self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
30 |
+
|
31 |
+
self.groupnorm_2 = nn.GroupNorm(32, out_channels)
|
32 |
+
self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
33 |
+
|
34 |
+
if in_channels == out_channels:
|
35 |
+
self.residual_layer = nn.Identity()
|
36 |
+
else:
|
37 |
+
self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
residue = x
|
41 |
+
x = self.groupnorm_1(x)
|
42 |
+
x = F.silu(x)
|
43 |
+
x = self.conv_1(x)
|
44 |
+
x = self.groupnorm_2(x)
|
45 |
+
x = F.silu(x)
|
46 |
+
x = self.conv_2(x)
|
47 |
+
|
48 |
+
return x + self.residual_layer(residue)
|
49 |
+
|
50 |
+
class VAE_Decoder(nn.Sequential):
|
51 |
+
def __init__(self):
|
52 |
+
super().__init__(
|
53 |
+
nn.Conv2d(4, 4, kernel_size=1, padding=0),
|
54 |
+
nn.Conv2d(4, 512, kernel_size=3, padding=1),
|
55 |
+
VAE_ResidualBlock(512, 512),
|
56 |
+
VAE_AttentionBlock(512),
|
57 |
+
VAE_ResidualBlock(512, 512),
|
58 |
+
VAE_ResidualBlock(512, 512),
|
59 |
+
VAE_ResidualBlock(512, 512),
|
60 |
+
VAE_ResidualBlock(512, 512),
|
61 |
+
|
62 |
+
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
|
63 |
+
nn.Upsample(scale_factor=2),
|
64 |
+
|
65 |
+
nn.Conv2d(512, 512, kernel_size=3, padding=1),
|
66 |
+
|
67 |
+
VAE_ResidualBlock(512, 512),
|
68 |
+
VAE_ResidualBlock(512, 512),
|
69 |
+
VAE_ResidualBlock(512, 512),
|
70 |
+
|
71 |
+
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
|
72 |
+
nn.Upsample(scale_factor=2),
|
73 |
+
|
74 |
+
nn.Conv2d(512, 512, kernel_size=3, padding=1),
|
75 |
+
|
76 |
+
VAE_ResidualBlock(512, 256),
|
77 |
+
VAE_ResidualBlock(256, 256),
|
78 |
+
VAE_ResidualBlock(256, 256),
|
79 |
+
|
80 |
+
nn.Upsample(scale_factor=2),
|
81 |
+
|
82 |
+
nn.Conv2d(256, 256, kernel_size=3, padding=1),
|
83 |
+
|
84 |
+
VAE_ResidualBlock(256, 128),
|
85 |
+
VAE_ResidualBlock(128, 128),
|
86 |
+
VAE_ResidualBlock(128, 128),
|
87 |
+
|
88 |
+
nn.GroupNorm(32, 128),
|
89 |
+
|
90 |
+
nn.SiLU(),
|
91 |
+
|
92 |
+
nn.Conv2d(128, 3, kernel_size=3, padding=1),
|
93 |
+
)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
x /= 0.18215
|
97 |
+
|
98 |
+
for module in self:
|
99 |
+
x = module(x)
|
100 |
+
return x
|
demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffusion.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from nanograd.models.stable_diffusion.attention import SelfAttention, CrossAttention
|
5 |
+
|
6 |
+
class TimeEmbedding(nn.Module):
|
7 |
+
def __init__(self, n_embd):
|
8 |
+
super().__init__()
|
9 |
+
self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
|
10 |
+
self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
x = self.linear_1(x)
|
14 |
+
x = F.silu(x)
|
15 |
+
x = self.linear_2(x)
|
16 |
+
|
17 |
+
return x
|
18 |
+
|
19 |
+
class UNET_ResidualBlock(nn.Module):
|
20 |
+
def __init__(self, in_channels, out_channels, n_time=1280):
|
21 |
+
super().__init__()
|
22 |
+
self.groupnorm_feature = nn.GroupNorm(32, in_channels)
|
23 |
+
self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
24 |
+
self.linear_time = nn.Linear(n_time, out_channels)
|
25 |
+
|
26 |
+
self.groupnorm_merged = nn.GroupNorm(32, out_channels)
|
27 |
+
self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
28 |
+
|
29 |
+
if in_channels == out_channels:
|
30 |
+
self.residual_layer = nn.Identity()
|
31 |
+
else:
|
32 |
+
self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
33 |
+
|
34 |
+
def forward(self, feature, time):
|
35 |
+
residue = feature
|
36 |
+
|
37 |
+
feature = self.groupnorm_feature(feature)
|
38 |
+
feature = F.silu(feature)
|
39 |
+
feature = self.conv_feature(feature)
|
40 |
+
|
41 |
+
time = F.silu(time)
|
42 |
+
|
43 |
+
time = self.linear_time(time)
|
44 |
+
merged = feature + time.unsqueeze(-1).unsqueeze(-1)
|
45 |
+
merged = self.groupnorm_merged(merged)
|
46 |
+
merged = F.silu(merged)
|
47 |
+
merged = self.conv_merged(merged)
|
48 |
+
|
49 |
+
return merged + self.residual_layer(residue)
|
50 |
+
|
51 |
+
class UNET_AttentionBlock(nn.Module):
|
52 |
+
def __init__(self, n_head: int, n_embd: int, d_context=768):
|
53 |
+
super().__init__()
|
54 |
+
channels = n_head * n_embd
|
55 |
+
|
56 |
+
self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
|
57 |
+
self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
|
58 |
+
|
59 |
+
self.layernorm_1 = nn.LayerNorm(channels)
|
60 |
+
self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
|
61 |
+
self.layernorm_2 = nn.LayerNorm(channels)
|
62 |
+
self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
|
63 |
+
self.layernorm_3 = nn.LayerNorm(channels)
|
64 |
+
self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
|
65 |
+
self.linear_geglu_2 = nn.Linear(4 * channels, channels)
|
66 |
+
|
67 |
+
self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
|
68 |
+
|
69 |
+
def forward(self, x, context):
|
70 |
+
residue_long = x
|
71 |
+
|
72 |
+
x = self.groupnorm(x)
|
73 |
+
x = self.conv_input(x)
|
74 |
+
|
75 |
+
n, c, h, w = x.shape
|
76 |
+
x = x.view((n, c, h * w))
|
77 |
+
|
78 |
+
x = x.transpose(-1, -2)
|
79 |
+
|
80 |
+
residue_short = x
|
81 |
+
|
82 |
+
x = self.layernorm_1(x)
|
83 |
+
x = self.attention_1(x)
|
84 |
+
x += residue_short
|
85 |
+
|
86 |
+
residue_short = x
|
87 |
+
|
88 |
+
x = self.layernorm_2(x)
|
89 |
+
x = self.attention_2(x, context)
|
90 |
+
|
91 |
+
x += residue_short
|
92 |
+
|
93 |
+
residue_short = x
|
94 |
+
|
95 |
+
x = self.layernorm_3(x)
|
96 |
+
|
97 |
+
# GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
|
98 |
+
x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
|
99 |
+
x = x * F.gelu(gate)
|
100 |
+
x = self.linear_geglu_2(x)
|
101 |
+
x += residue_short
|
102 |
+
x = x.transpose(-1, -2)
|
103 |
+
|
104 |
+
x = x.view((n, c, h, w))
|
105 |
+
|
106 |
+
return self.conv_output(x) + residue_long
|
107 |
+
|
108 |
+
class Upsample(nn.Module):
|
109 |
+
def __init__(self, channels):
|
110 |
+
super().__init__()
|
111 |
+
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
115 |
+
return self.conv(x)
|
116 |
+
|
117 |
+
class SwitchSequential(nn.Sequential):
|
118 |
+
def forward(self, x, context, time):
|
119 |
+
for layer in self:
|
120 |
+
if isinstance(layer, UNET_AttentionBlock):
|
121 |
+
x = layer(x, context)
|
122 |
+
elif isinstance(layer, UNET_ResidualBlock):
|
123 |
+
x = layer(x, time)
|
124 |
+
else:
|
125 |
+
x = layer(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
class UNET(nn.Module):
|
129 |
+
def __init__(self):
|
130 |
+
super().__init__()
|
131 |
+
self.encoders = nn.ModuleList([
|
132 |
+
SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
|
133 |
+
SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
|
134 |
+
SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
|
135 |
+
SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
|
136 |
+
SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
|
137 |
+
SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
|
138 |
+
SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
|
139 |
+
SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
|
140 |
+
SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
|
141 |
+
SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
|
142 |
+
SwitchSequential(UNET_ResidualBlock(1280, 1280)),
|
143 |
+
SwitchSequential(UNET_ResidualBlock(1280, 1280)),
|
144 |
+
])
|
145 |
+
|
146 |
+
self.bottleneck = SwitchSequential(
|
147 |
+
UNET_ResidualBlock(1280, 1280),
|
148 |
+
UNET_AttentionBlock(8, 160),
|
149 |
+
UNET_ResidualBlock(1280, 1280),
|
150 |
+
)
|
151 |
+
|
152 |
+
self.decoders = nn.ModuleList([
|
153 |
+
SwitchSequential(UNET_ResidualBlock(2560, 1280)),
|
154 |
+
SwitchSequential(UNET_ResidualBlock(2560, 1280)),
|
155 |
+
SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
|
156 |
+
SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
|
157 |
+
|
158 |
+
SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
|
159 |
+
SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
|
160 |
+
|
161 |
+
SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
|
162 |
+
SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
|
163 |
+
|
164 |
+
SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
|
165 |
+
SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
|
166 |
+
SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
|
167 |
+
|
168 |
+
SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
|
169 |
+
])
|
170 |
+
|
171 |
+
def forward(self, x, context, time):
|
172 |
+
skip_connections = []
|
173 |
+
for layers in self.encoders:
|
174 |
+
x = layers(x, context, time)
|
175 |
+
skip_connections.append(x)
|
176 |
+
|
177 |
+
x = self.bottleneck(x, context, time)
|
178 |
+
|
179 |
+
for layers in self.decoders:
|
180 |
+
x = torch.cat((x, skip_connections.pop()), dim=1)
|
181 |
+
x = layers(x, context, time)
|
182 |
+
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class UNET_OutputLayer(nn.Module):
|
187 |
+
def __init__(self, in_channels, out_channels):
|
188 |
+
super().__init__()
|
189 |
+
self.groupnorm = nn.GroupNorm(32, in_channels)
|
190 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
191 |
+
|
192 |
+
def forward(self, x):
|
193 |
+
x = self.groupnorm(x)
|
194 |
+
x = F.silu(x)
|
195 |
+
x = self.conv(x)
|
196 |
+
|
197 |
+
return x
|
198 |
+
|
199 |
+
class Diffusion(nn.Module):
|
200 |
+
def __init__(self):
|
201 |
+
super().__init__()
|
202 |
+
self.time_embedding = TimeEmbedding(320)
|
203 |
+
self.unet = UNET()
|
204 |
+
self.final = UNET_OutputLayer(320, 4)
|
205 |
+
|
206 |
+
def forward(self, latent, context, time):
|
207 |
+
time = self.time_embedding(time)
|
208 |
+
|
209 |
+
output = self.unet(latent, context, time)
|
210 |
+
|
211 |
+
output = self.final(output)
|
212 |
+
|
213 |
+
return output
|
encoder.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from nanograd.models.stable_diffusion.decoder import VAE_AttentionBlock, VAE_ResidualBlock
|
5 |
+
|
6 |
+
class VAE_Encoder(nn.Sequential):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__(
|
9 |
+
nn.Conv2d(3, 128, kernel_size=3, padding=1),
|
10 |
+
|
11 |
+
VAE_ResidualBlock(128, 128),
|
12 |
+
VAE_ResidualBlock(128, 128),
|
13 |
+
|
14 |
+
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
|
15 |
+
|
16 |
+
VAE_ResidualBlock(128, 256),
|
17 |
+
VAE_ResidualBlock(256, 256),
|
18 |
+
|
19 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
|
20 |
+
|
21 |
+
VAE_ResidualBlock(256, 512),
|
22 |
+
VAE_ResidualBlock(512, 512),
|
23 |
+
|
24 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
|
25 |
+
|
26 |
+
VAE_ResidualBlock(512, 512),
|
27 |
+
VAE_ResidualBlock(512, 512),
|
28 |
+
VAE_ResidualBlock(512, 512),
|
29 |
+
VAE_AttentionBlock(512),
|
30 |
+
VAE_ResidualBlock(512, 512),
|
31 |
+
|
32 |
+
nn.GroupNorm(32, 512),
|
33 |
+
|
34 |
+
nn.SiLU(),
|
35 |
+
|
36 |
+
nn.Conv2d(512, 8, kernel_size=3, padding=1),
|
37 |
+
|
38 |
+
nn.Conv2d(8, 8, kernel_size=1, padding=0),
|
39 |
+
)
|
40 |
+
|
41 |
+
def forward(self, x, noise):
|
42 |
+
for module in self:
|
43 |
+
|
44 |
+
if getattr(module, 'stride', None) == (2, 2):
|
45 |
+
x = F.pad(x, (0, 1, 0, 1))
|
46 |
+
|
47 |
+
x = module(x)
|
48 |
+
mean, log_variance = torch.chunk(x, 2, dim=1)
|
49 |
+
log_variance = torch.clamp(log_variance, -30, 20)
|
50 |
+
variance = log_variance.exp()
|
51 |
+
stdev = variance.sqrt()
|
52 |
+
x = mean + stdev * noise
|
53 |
+
# Constant taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L17C1-L17C1
|
54 |
+
x *= 0.18215
|
55 |
+
|
56 |
+
return x
|
gaussing_diffusion.py
ADDED
@@ -0,0 +1,814 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import enum
|
5 |
+
|
6 |
+
class GaussingDistribution:
|
7 |
+
def __init__(self, parameters: torch.Tensor) -> None:
|
8 |
+
self.mean, log_variance = torch.chunk(parameters, 2, dim=1)
|
9 |
+
self.log_variance = torch.clamp(log_variance, -30.0, 20.0)
|
10 |
+
self.std = torch.exp(0.5 * self.log_variance)
|
11 |
+
|
12 |
+
def sample(self):
|
13 |
+
return self.mean + self.std * torch.rand_like(self.std)
|
14 |
+
|
15 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
16 |
+
tensor = None
|
17 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
18 |
+
if isinstance(obj, torch.Tensor):
|
19 |
+
tensor = obj
|
20 |
+
break
|
21 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
22 |
+
|
23 |
+
logvar1, logvar2 = [
|
24 |
+
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
25 |
+
for x in (logvar1, logvar2)
|
26 |
+
]
|
27 |
+
|
28 |
+
return 0.5 * (
|
29 |
+
-1.0
|
30 |
+
+ logvar2
|
31 |
+
- logvar1
|
32 |
+
+ torch.exp(logvar1 - logvar2)
|
33 |
+
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def approx_standard_normal_cdf(x):
|
38 |
+
|
39 |
+
return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
40 |
+
|
41 |
+
|
42 |
+
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
43 |
+
|
44 |
+
centered_x = x - means
|
45 |
+
inv_stdv = torch.exp(-log_scales)
|
46 |
+
normalized_x = centered_x * inv_stdv
|
47 |
+
log_probs = torch.distributions.Normal(torch.zeros_like(x), torch.ones_like(x)).log_prob(normalized_x)
|
48 |
+
return log_probs
|
49 |
+
|
50 |
+
|
51 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
52 |
+
|
53 |
+
assert x.shape == means.shape == log_scales.shape
|
54 |
+
centered_x = x - means
|
55 |
+
inv_stdv = torch.exp(-log_scales)
|
56 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
57 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
58 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
59 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
60 |
+
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
|
61 |
+
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
|
62 |
+
cdf_delta = cdf_plus - cdf_min
|
63 |
+
log_probs = torch.where(
|
64 |
+
x < -0.999,
|
65 |
+
log_cdf_plus,
|
66 |
+
torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
|
67 |
+
)
|
68 |
+
assert log_probs.shape == x.shape
|
69 |
+
return log_probs
|
70 |
+
|
71 |
+
################# Gaussing ####################
|
72 |
+
|
73 |
+
def mean_flat(tensor):
|
74 |
+
|
75 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
76 |
+
|
77 |
+
|
78 |
+
class ModelMeanType(enum.Enum):
|
79 |
+
|
80 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
81 |
+
START_X = enum.auto() # the model predicts x_0
|
82 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
83 |
+
|
84 |
+
|
85 |
+
class ModelVarType(enum.Enum):
|
86 |
+
|
87 |
+
LEARNED = enum.auto()
|
88 |
+
FIXED_SMALL = enum.auto()
|
89 |
+
FIXED_LARGE = enum.auto()
|
90 |
+
LEARNED_RANGE = enum.auto()
|
91 |
+
|
92 |
+
|
93 |
+
class LossType(enum.Enum):
|
94 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
95 |
+
RESCALED_MSE = (
|
96 |
+
enum.auto()
|
97 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
98 |
+
KL = enum.auto() # use the variational lower-bound
|
99 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
100 |
+
|
101 |
+
def is_vb(self):
|
102 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
103 |
+
|
104 |
+
|
105 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
106 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
107 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
108 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
109 |
+
return betas
|
110 |
+
|
111 |
+
|
112 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
113 |
+
if beta_schedule == "quad":
|
114 |
+
betas = (
|
115 |
+
np.linspace(
|
116 |
+
beta_start ** 0.5,
|
117 |
+
beta_end ** 0.5,
|
118 |
+
num_diffusion_timesteps,
|
119 |
+
dtype=np.float64,
|
120 |
+
)
|
121 |
+
** 2
|
122 |
+
)
|
123 |
+
elif beta_schedule == "linear":
|
124 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
125 |
+
elif beta_schedule == "warmup10":
|
126 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
127 |
+
elif beta_schedule == "warmup50":
|
128 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
129 |
+
elif beta_schedule == "const":
|
130 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
131 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
132 |
+
betas = 1.0 / np.linspace(
|
133 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
raise NotImplementedError(beta_schedule)
|
137 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
138 |
+
return betas
|
139 |
+
|
140 |
+
|
141 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
142 |
+
|
143 |
+
if schedule_name == "linear":
|
144 |
+
|
145 |
+
scale = 1000 / num_diffusion_timesteps
|
146 |
+
return get_beta_schedule(
|
147 |
+
"linear",
|
148 |
+
beta_start=scale * 0.0001,
|
149 |
+
beta_end=scale * 0.02,
|
150 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
151 |
+
)
|
152 |
+
elif schedule_name == "squaredcos_cap_v2":
|
153 |
+
return betas_for_alpha_bar(
|
154 |
+
num_diffusion_timesteps,
|
155 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
156 |
+
)
|
157 |
+
else:
|
158 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
159 |
+
|
160 |
+
|
161 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
162 |
+
|
163 |
+
betas = []
|
164 |
+
for i in range(num_diffusion_timesteps):
|
165 |
+
t1 = i / num_diffusion_timesteps
|
166 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
167 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
168 |
+
return np.array(betas)
|
169 |
+
|
170 |
+
|
171 |
+
class GaussianDiffusion:
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
*,
|
175 |
+
betas,
|
176 |
+
model_mean_type,
|
177 |
+
model_var_type,
|
178 |
+
loss_type
|
179 |
+
):
|
180 |
+
|
181 |
+
self.model_mean_type = model_mean_type
|
182 |
+
self.model_var_type = model_var_type
|
183 |
+
self.loss_type = loss_type
|
184 |
+
|
185 |
+
# Use float64 for accuracy.
|
186 |
+
betas = np.array(betas, dtype=np.float64)
|
187 |
+
self.betas = betas
|
188 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
189 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
190 |
+
|
191 |
+
self.num_timesteps = int(betas.shape[0])
|
192 |
+
|
193 |
+
alphas = 1.0 - betas
|
194 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
195 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
196 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
197 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
198 |
+
|
199 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
200 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
201 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
202 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
203 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
204 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
205 |
+
|
206 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
207 |
+
self.posterior_variance = (
|
208 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
209 |
+
)
|
210 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
211 |
+
self.posterior_log_variance_clipped = np.log(
|
212 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
213 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
214 |
+
|
215 |
+
self.posterior_mean_coef1 = (
|
216 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
217 |
+
)
|
218 |
+
self.posterior_mean_coef2 = (
|
219 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
220 |
+
)
|
221 |
+
|
222 |
+
def q_mean_variance(self, x_start, t):
|
223 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
224 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
225 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
226 |
+
return mean, variance, log_variance
|
227 |
+
|
228 |
+
def q_sample(self, x_start, t, noise=None):
|
229 |
+
if noise is None:
|
230 |
+
noise = torch.randn_like(x_start)
|
231 |
+
assert noise.shape == x_start.shape
|
232 |
+
return (
|
233 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
234 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
235 |
+
)
|
236 |
+
|
237 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
238 |
+
assert x_start.shape == x_t.shape
|
239 |
+
posterior_mean = (
|
240 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
241 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
242 |
+
)
|
243 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
244 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
245 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
246 |
+
)
|
247 |
+
assert (
|
248 |
+
posterior_mean.shape[0]
|
249 |
+
== posterior_variance.shape[0]
|
250 |
+
== posterior_log_variance_clipped.shape[0]
|
251 |
+
== x_start.shape[0]
|
252 |
+
)
|
253 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
254 |
+
|
255 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
256 |
+
if model_kwargs is None:
|
257 |
+
model_kwargs = {}
|
258 |
+
|
259 |
+
B, C = x.shape[:2]
|
260 |
+
assert t.shape == (B,)
|
261 |
+
model_output = model(x, t, **model_kwargs)
|
262 |
+
if isinstance(model_output, tuple):
|
263 |
+
model_output, extra = model_output
|
264 |
+
else:
|
265 |
+
extra = None
|
266 |
+
|
267 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
268 |
+
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
269 |
+
model_output, model_var_values = torch.split(model_output, C, dim=1)
|
270 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
271 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
272 |
+
frac = (model_var_values + 1) / 2
|
273 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
274 |
+
model_variance = torch.exp(model_log_variance)
|
275 |
+
else:
|
276 |
+
model_variance, model_log_variance = {
|
277 |
+
ModelVarType.FIXED_LARGE: (
|
278 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
279 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
280 |
+
),
|
281 |
+
ModelVarType.FIXED_SMALL: (
|
282 |
+
self.posterior_variance,
|
283 |
+
self.posterior_log_variance_clipped,
|
284 |
+
),
|
285 |
+
}[self.model_var_type]
|
286 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
287 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
288 |
+
|
289 |
+
def process_xstart(x):
|
290 |
+
if denoised_fn is not None:
|
291 |
+
x = denoised_fn(x)
|
292 |
+
if clip_denoised:
|
293 |
+
return x.clamp(-1, 1)
|
294 |
+
return x
|
295 |
+
|
296 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
297 |
+
pred_xstart = process_xstart(model_output)
|
298 |
+
else:
|
299 |
+
pred_xstart = process_xstart(
|
300 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
301 |
+
)
|
302 |
+
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
303 |
+
|
304 |
+
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
305 |
+
return {
|
306 |
+
"mean": model_mean,
|
307 |
+
"variance": model_variance,
|
308 |
+
"log_variance": model_log_variance,
|
309 |
+
"pred_xstart": pred_xstart,
|
310 |
+
"extra": extra,
|
311 |
+
}
|
312 |
+
|
313 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
314 |
+
assert x_t.shape == eps.shape
|
315 |
+
return (
|
316 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
317 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
318 |
+
)
|
319 |
+
|
320 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
321 |
+
return (
|
322 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
323 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
324 |
+
|
325 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
326 |
+
gradient = cond_fn(x, t, **model_kwargs)
|
327 |
+
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
328 |
+
return new_mean
|
329 |
+
|
330 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
331 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
332 |
+
|
333 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
334 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
335 |
+
|
336 |
+
out = p_mean_var.copy()
|
337 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
338 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
339 |
+
return out
|
340 |
+
|
341 |
+
def p_sample(
|
342 |
+
self,
|
343 |
+
model,
|
344 |
+
x,
|
345 |
+
t,
|
346 |
+
clip_denoised=True,
|
347 |
+
denoised_fn=None,
|
348 |
+
cond_fn=None,
|
349 |
+
model_kwargs=None,
|
350 |
+
):
|
351 |
+
out = self.p_mean_variance(
|
352 |
+
model,
|
353 |
+
x,
|
354 |
+
t,
|
355 |
+
clip_denoised=clip_denoised,
|
356 |
+
denoised_fn=denoised_fn,
|
357 |
+
model_kwargs=model_kwargs,
|
358 |
+
)
|
359 |
+
noise = torch.randn_like(x)
|
360 |
+
nonzero_mask = (
|
361 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
362 |
+
) # no noise when t == 0
|
363 |
+
if cond_fn is not None:
|
364 |
+
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
365 |
+
sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
|
366 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
367 |
+
|
368 |
+
def p_sample_loop(
|
369 |
+
self,
|
370 |
+
model,
|
371 |
+
shape,
|
372 |
+
noise=None,
|
373 |
+
clip_denoised=True,
|
374 |
+
denoised_fn=None,
|
375 |
+
cond_fn=None,
|
376 |
+
model_kwargs=None,
|
377 |
+
device=None,
|
378 |
+
progress=False,
|
379 |
+
):
|
380 |
+
final = None
|
381 |
+
for sample in self.p_sample_loop_progressive(
|
382 |
+
model,
|
383 |
+
shape,
|
384 |
+
noise=noise,
|
385 |
+
clip_denoised=clip_denoised,
|
386 |
+
denoised_fn=denoised_fn,
|
387 |
+
cond_fn=cond_fn,
|
388 |
+
model_kwargs=model_kwargs,
|
389 |
+
device=device,
|
390 |
+
progress=progress,
|
391 |
+
):
|
392 |
+
final = sample
|
393 |
+
return final["sample"]
|
394 |
+
|
395 |
+
def p_sample_loop_progressive(
|
396 |
+
self,
|
397 |
+
model,
|
398 |
+
shape,
|
399 |
+
noise=None,
|
400 |
+
clip_denoised=True,
|
401 |
+
denoised_fn=None,
|
402 |
+
cond_fn=None,
|
403 |
+
model_kwargs=None,
|
404 |
+
device=None,
|
405 |
+
progress=False,
|
406 |
+
):
|
407 |
+
if device is None:
|
408 |
+
device = next(model.parameters()).device
|
409 |
+
assert isinstance(shape, (tuple, list))
|
410 |
+
if noise is not None:
|
411 |
+
img = noise
|
412 |
+
else:
|
413 |
+
img = torch.randn(*shape, device=device)
|
414 |
+
indices = list(range(self.num_timesteps))[::-1]
|
415 |
+
|
416 |
+
if progress:
|
417 |
+
# Lazy import so that we don't depend on tqdm.
|
418 |
+
from tqdm.auto import tqdm
|
419 |
+
|
420 |
+
indices = tqdm(indices)
|
421 |
+
|
422 |
+
for i in indices:
|
423 |
+
t = torch.tensor([i] * shape[0], device=device)
|
424 |
+
with torch.no_grad():
|
425 |
+
out = self.p_sample(
|
426 |
+
model,
|
427 |
+
img,
|
428 |
+
t,
|
429 |
+
clip_denoised=clip_denoised,
|
430 |
+
denoised_fn=denoised_fn,
|
431 |
+
cond_fn=cond_fn,
|
432 |
+
model_kwargs=model_kwargs,
|
433 |
+
)
|
434 |
+
yield out
|
435 |
+
img = out["sample"]
|
436 |
+
|
437 |
+
def ddim_sample(
|
438 |
+
self,
|
439 |
+
model,
|
440 |
+
x,
|
441 |
+
t,
|
442 |
+
clip_denoised=True,
|
443 |
+
denoised_fn=None,
|
444 |
+
cond_fn=None,
|
445 |
+
model_kwargs=None,
|
446 |
+
eta=0.0,
|
447 |
+
):
|
448 |
+
out = self.p_mean_variance(
|
449 |
+
model,
|
450 |
+
x,
|
451 |
+
t,
|
452 |
+
clip_denoised=clip_denoised,
|
453 |
+
denoised_fn=denoised_fn,
|
454 |
+
model_kwargs=model_kwargs,
|
455 |
+
)
|
456 |
+
if cond_fn is not None:
|
457 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
458 |
+
|
459 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
460 |
+
|
461 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
462 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
463 |
+
sigma = (
|
464 |
+
eta
|
465 |
+
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
466 |
+
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
|
467 |
+
)
|
468 |
+
# Equation 12.
|
469 |
+
noise = torch.randn_like(x)
|
470 |
+
mean_pred = (
|
471 |
+
out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
|
472 |
+
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
473 |
+
)
|
474 |
+
nonzero_mask = (
|
475 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
476 |
+
) # no noise when t == 0
|
477 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
478 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
479 |
+
|
480 |
+
def ddim_reverse_sample(
|
481 |
+
self,
|
482 |
+
model,
|
483 |
+
x,
|
484 |
+
t,
|
485 |
+
clip_denoised=True,
|
486 |
+
denoised_fn=None,
|
487 |
+
cond_fn=None,
|
488 |
+
model_kwargs=None,
|
489 |
+
eta=0.0,
|
490 |
+
):
|
491 |
+
|
492 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
493 |
+
out = self.p_mean_variance(
|
494 |
+
model,
|
495 |
+
x,
|
496 |
+
t,
|
497 |
+
clip_denoised=clip_denoised,
|
498 |
+
denoised_fn=denoised_fn,
|
499 |
+
model_kwargs=model_kwargs,
|
500 |
+
)
|
501 |
+
if cond_fn is not None:
|
502 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
503 |
+
eps = (
|
504 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
505 |
+
- out["pred_xstart"]
|
506 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
507 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
508 |
+
|
509 |
+
# Equation 12. reversed
|
510 |
+
mean_pred = out["pred_xstart"] * torch.sqrt(alpha_bar_next) + torch.sqrt(1 - alpha_bar_next) * eps
|
511 |
+
|
512 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
513 |
+
|
514 |
+
def ddim_sample_loop(
|
515 |
+
self,
|
516 |
+
model,
|
517 |
+
shape,
|
518 |
+
noise=None,
|
519 |
+
clip_denoised=True,
|
520 |
+
denoised_fn=None,
|
521 |
+
cond_fn=None,
|
522 |
+
model_kwargs=None,
|
523 |
+
device=None,
|
524 |
+
progress=False,
|
525 |
+
eta=0.0,
|
526 |
+
):
|
527 |
+
final = None
|
528 |
+
for sample in self.ddim_sample_loop_progressive(
|
529 |
+
model,
|
530 |
+
shape,
|
531 |
+
noise=noise,
|
532 |
+
clip_denoised=clip_denoised,
|
533 |
+
denoised_fn=denoised_fn,
|
534 |
+
cond_fn=cond_fn,
|
535 |
+
model_kwargs=model_kwargs,
|
536 |
+
device=device,
|
537 |
+
progress=progress,
|
538 |
+
eta=eta,
|
539 |
+
):
|
540 |
+
final = sample
|
541 |
+
return final["sample"]
|
542 |
+
|
543 |
+
def ddim_sample_loop_progressive(
|
544 |
+
self,
|
545 |
+
model,
|
546 |
+
shape,
|
547 |
+
noise=None,
|
548 |
+
clip_denoised=True,
|
549 |
+
denoised_fn=None,
|
550 |
+
cond_fn=None,
|
551 |
+
model_kwargs=None,
|
552 |
+
device=None,
|
553 |
+
progress=False,
|
554 |
+
eta=0.0,
|
555 |
+
):
|
556 |
+
if device is None:
|
557 |
+
device = next(model.parameters()).device
|
558 |
+
assert isinstance(shape, (tuple, list))
|
559 |
+
if noise is not None:
|
560 |
+
img = noise
|
561 |
+
else:
|
562 |
+
img = torch.randn(*shape, device=device)
|
563 |
+
indices = list(range(self.num_timesteps))[::-1]
|
564 |
+
|
565 |
+
if progress:
|
566 |
+
# Lazy import so that we don't depend on tqdm.
|
567 |
+
from tqdm.auto import tqdm
|
568 |
+
|
569 |
+
indices = tqdm(indices)
|
570 |
+
|
571 |
+
for i in indices:
|
572 |
+
t = torch.tensor([i] * shape[0], device=device)
|
573 |
+
with torch.no_grad():
|
574 |
+
out = self.ddim_sample(
|
575 |
+
model,
|
576 |
+
img,
|
577 |
+
t,
|
578 |
+
clip_denoised=clip_denoised,
|
579 |
+
denoised_fn=denoised_fn,
|
580 |
+
cond_fn=cond_fn,
|
581 |
+
model_kwargs=model_kwargs,
|
582 |
+
eta=eta,
|
583 |
+
)
|
584 |
+
yield out
|
585 |
+
img = out["sample"]
|
586 |
+
|
587 |
+
def _vb_terms_bpd(
|
588 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
589 |
+
):
|
590 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
591 |
+
x_start=x_start, x_t=x_t, t=t
|
592 |
+
)
|
593 |
+
out = self.p_mean_variance(
|
594 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
595 |
+
)
|
596 |
+
kl = normal_kl(
|
597 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
598 |
+
)
|
599 |
+
kl = mean_flat(kl) / np.log(2.0)
|
600 |
+
|
601 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
602 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
603 |
+
)
|
604 |
+
assert decoder_nll.shape == x_start.shape
|
605 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
606 |
+
|
607 |
+
# At the first timestep return the decoder NLL,
|
608 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
609 |
+
output = torch.where((t == 0), decoder_nll, kl)
|
610 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
611 |
+
|
612 |
+
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
613 |
+
|
614 |
+
if model_kwargs is None:
|
615 |
+
model_kwargs = {}
|
616 |
+
if noise is None:
|
617 |
+
noise = torch.randn_like(x_start)
|
618 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
619 |
+
|
620 |
+
terms = {}
|
621 |
+
|
622 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
623 |
+
terms["loss"] = self._vb_terms_bpd(
|
624 |
+
model=model,
|
625 |
+
x_start=x_start,
|
626 |
+
x_t=x_t,
|
627 |
+
t=t,
|
628 |
+
clip_denoised=False,
|
629 |
+
model_kwargs=model_kwargs,
|
630 |
+
)["output"]
|
631 |
+
if self.loss_type == LossType.RESCALED_KL:
|
632 |
+
terms["loss"] *= self.num_timesteps
|
633 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
634 |
+
model_output = model(x_t, t, **model_kwargs)
|
635 |
+
|
636 |
+
if self.model_var_type in [
|
637 |
+
ModelVarType.LEARNED,
|
638 |
+
ModelVarType.LEARNED_RANGE,
|
639 |
+
]:
|
640 |
+
B, C = x_t.shape[:2]
|
641 |
+
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
642 |
+
model_output, model_var_values = torch.split(model_output, C, dim=1)
|
643 |
+
|
644 |
+
frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
|
645 |
+
terms["vb"] = self._vb_terms_bpd(
|
646 |
+
model=lambda *args, r=frozen_out: r,
|
647 |
+
x_start=x_start,
|
648 |
+
x_t=x_t,
|
649 |
+
t=t,
|
650 |
+
clip_denoised=False,
|
651 |
+
)["output"]
|
652 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
653 |
+
|
654 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
655 |
+
|
656 |
+
target = {
|
657 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
658 |
+
x_start=x_start, x_t=x_t, t=t
|
659 |
+
)[0],
|
660 |
+
ModelMeanType.START_X: x_start,
|
661 |
+
ModelMeanType.EPSILON: noise,
|
662 |
+
}[self.model_mean_type]
|
663 |
+
assert model_output.shape == target.shape == x_start.shape
|
664 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
665 |
+
if "vb" in terms:
|
666 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
667 |
+
else:
|
668 |
+
terms["loss"] = terms["mse"]
|
669 |
+
else:
|
670 |
+
raise NotImplementedError(self.loss_type)
|
671 |
+
|
672 |
+
return terms
|
673 |
+
|
674 |
+
def _prior_bpd(self, x_start):
|
675 |
+
batch_size = x_start.shape[0]
|
676 |
+
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
677 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
678 |
+
kl_prior = normal_kl(
|
679 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
680 |
+
)
|
681 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
682 |
+
|
683 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
684 |
+
|
685 |
+
device = x_start.device
|
686 |
+
batch_size = x_start.shape[0]
|
687 |
+
|
688 |
+
vb = []
|
689 |
+
xstart_mse = []
|
690 |
+
mse = []
|
691 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
692 |
+
t_batch = torch.tensor([t] * batch_size, device=device)
|
693 |
+
noise = torch.randn_like(x_start)
|
694 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
695 |
+
# Calculate VLB term at the current timestep
|
696 |
+
with torch.no_grad():
|
697 |
+
out = self._vb_terms_bpd(
|
698 |
+
model,
|
699 |
+
x_start=x_start,
|
700 |
+
x_t=x_t,
|
701 |
+
t=t_batch,
|
702 |
+
clip_denoised=clip_denoised,
|
703 |
+
model_kwargs=model_kwargs,
|
704 |
+
)
|
705 |
+
vb.append(out["output"])
|
706 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
707 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
708 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
709 |
+
|
710 |
+
vb = torch.stack(vb, dim=1)
|
711 |
+
xstart_mse = torch.stack(xstart_mse, dim=1)
|
712 |
+
mse = torch.stack(mse, dim=1)
|
713 |
+
|
714 |
+
prior_bpd = self._prior_bpd(x_start)
|
715 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
716 |
+
return {
|
717 |
+
"total_bpd": total_bpd,
|
718 |
+
"prior_bpd": prior_bpd,
|
719 |
+
"vb": vb,
|
720 |
+
"xstart_mse": xstart_mse,
|
721 |
+
"mse": mse,
|
722 |
+
}
|
723 |
+
|
724 |
+
|
725 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
726 |
+
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
727 |
+
while len(res.shape) < len(broadcast_shape):
|
728 |
+
res = res[..., None]
|
729 |
+
return res + torch.zeros(broadcast_shape, device=timesteps.device)
|
730 |
+
|
731 |
+
############################### Denoising Diffusion Probabilistic Model###################################
|
732 |
+
class DDPMSampler:
|
733 |
+
def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
|
734 |
+
self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
|
735 |
+
self.alphas = 1.0 - self.betas
|
736 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, d_model=0)
|
737 |
+
self.one = torch.tensor(1.0)
|
738 |
+
self.generator = generator
|
739 |
+
self.num_train_timesteps = num_training_steps
|
740 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
|
741 |
+
|
742 |
+
def set_inference_timesteps(self, num_inference_steps=50):
|
743 |
+
self.num_inference_steps = num_inference_steps
|
744 |
+
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
745 |
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
746 |
+
self.timesteps = torch.from_numpy(timesteps)
|
747 |
+
|
748 |
+
def _get_previous_timestep(self, timestep: int) -> int:
|
749 |
+
prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
|
750 |
+
return prev_t
|
751 |
+
|
752 |
+
def _get_variance(self, timestep: int) -> torch.Tensor:
|
753 |
+
prev_t = self._get_previous_timestep(timestep)
|
754 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
755 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
|
756 |
+
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
|
757 |
+
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
|
758 |
+
variance = torch.clamp(variance, min=1e-20)
|
759 |
+
return variance
|
760 |
+
|
761 |
+
def set_strength(self, strength=1):
|
762 |
+
"""
|
763 |
+
Set how much noise to add to the input image.
|
764 |
+
More noise (strength ~ 1) means that the output will be further from the input image.
|
765 |
+
Less noise (strength ~ 0) means that the output will be closer to the input image.
|
766 |
+
"""
|
767 |
+
|
768 |
+
start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
|
769 |
+
self.timesteps = self.timesteps[start_step:]
|
770 |
+
self.start_step = start_step
|
771 |
+
|
772 |
+
def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
|
773 |
+
t = timestep
|
774 |
+
prev_t = self._get_previous_timestep(t)
|
775 |
+
alpha_prod_t = self.alphas_cumprod[t]
|
776 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
|
777 |
+
beta_prod_t = 1 - alpha_prod_t
|
778 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
779 |
+
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
|
780 |
+
current_beta_t = 1 - current_alpha_t
|
781 |
+
pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
782 |
+
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
|
783 |
+
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
|
784 |
+
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
|
785 |
+
variance = 0
|
786 |
+
if t > 0:
|
787 |
+
device = model_output.device
|
788 |
+
noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
|
789 |
+
|
790 |
+
variance = (self._get_variance(t) ** 0.5) * noise
|
791 |
+
pred_prev_sample = pred_prev_sample + variance
|
792 |
+
|
793 |
+
return pred_prev_sample
|
794 |
+
|
795 |
+
def add_noise(
|
796 |
+
self,
|
797 |
+
original_samples: torch.FloatTensor,
|
798 |
+
timesteps: torch.IntTensor,
|
799 |
+
) -> torch.FloatTensor:
|
800 |
+
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
801 |
+
timesteps = timesteps.to(original_samples.device)
|
802 |
+
|
803 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
804 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
805 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
806 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
807 |
+
|
808 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
809 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
810 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
811 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
812 |
+
noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
|
813 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
814 |
+
return noisy_samples
|
model_converter.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model_loader.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from nanograd.models.stable_diffusion.clip import CLIP
|
2 |
+
from nanograd.models.stable_diffusion.encoder import VAE_Encoder
|
3 |
+
from nanograd.models.stable_diffusion.decoder import VAE_Decoder
|
4 |
+
from nanograd.models.stable_diffusion.diffusion import Diffusion
|
5 |
+
|
6 |
+
from nanograd.models.stable_diffusion import model_converter
|
7 |
+
|
8 |
+
def preload_models_from_standard_weights(ckpt_path, device):
|
9 |
+
state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
|
10 |
+
|
11 |
+
encoder = VAE_Encoder().to(device)
|
12 |
+
encoder.load_state_dict(state_dict['encoder'], strict=True)
|
13 |
+
|
14 |
+
decoder = VAE_Decoder().to(device)
|
15 |
+
decoder.load_state_dict(state_dict['decoder'], strict=True)
|
16 |
+
|
17 |
+
diffusion = Diffusion().to(device)
|
18 |
+
diffusion.load_state_dict(state_dict['diffusion'], strict=True)
|
19 |
+
|
20 |
+
clip = CLIP().to(device)
|
21 |
+
clip.load_state_dict(state_dict['clip'], strict=True)
|
22 |
+
|
23 |
+
return {
|
24 |
+
'clip': clip,
|
25 |
+
'encoder': encoder,
|
26 |
+
'decoder': decoder,
|
27 |
+
'diffusion': diffusion,
|
28 |
+
}
|
pipeline.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
from nanograd.models.stable_diffusion.ddpm import DDPMSampler
|
5 |
+
|
6 |
+
WIDTH = 512
|
7 |
+
HEIGHT = 512
|
8 |
+
LATENTS_WIDTH = WIDTH // 8
|
9 |
+
LATENTS_HEIGHT = HEIGHT // 8
|
10 |
+
|
11 |
+
def generate(
|
12 |
+
prompt,
|
13 |
+
uncond_prompt=None,
|
14 |
+
input_image=None,
|
15 |
+
strength=0.8,
|
16 |
+
do_cfg=True,
|
17 |
+
cfg_scale=7.5,
|
18 |
+
sampler_name="ddpm",
|
19 |
+
n_inference_steps=50,
|
20 |
+
models={},
|
21 |
+
seed=None,
|
22 |
+
device=None,
|
23 |
+
idle_device=None,
|
24 |
+
tokenizer=None,
|
25 |
+
):
|
26 |
+
with torch.no_grad():
|
27 |
+
if not 0 < strength <= 1:
|
28 |
+
raise ValueError("strength must be between 0 and 1")
|
29 |
+
|
30 |
+
if idle_device:
|
31 |
+
to_idle = lambda x: x.to(idle_device)
|
32 |
+
else:
|
33 |
+
to_idle = lambda x: x
|
34 |
+
|
35 |
+
generator = torch.Generator(device=device)
|
36 |
+
if seed is None:
|
37 |
+
generator.seed()
|
38 |
+
else:
|
39 |
+
generator.manual_seed(seed)
|
40 |
+
|
41 |
+
clip = models["clip"]
|
42 |
+
clip.to(device)
|
43 |
+
|
44 |
+
if do_cfg:
|
45 |
+
cond_tokens = tokenizer.batch_encode_plus(
|
46 |
+
[prompt], padding="max_length", max_length=77
|
47 |
+
).input_ids
|
48 |
+
cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
|
49 |
+
cond_context = clip(cond_tokens)
|
50 |
+
uncond_tokens = tokenizer.batch_encode_plus(
|
51 |
+
[uncond_prompt], padding="max_length", max_length=77
|
52 |
+
).input_ids
|
53 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
|
54 |
+
uncond_context = clip(uncond_tokens)
|
55 |
+
context = torch.cat([cond_context, uncond_context])
|
56 |
+
else:
|
57 |
+
tokens = tokenizer.batch_encode_plus(
|
58 |
+
[prompt], padding="max_length", max_length=77
|
59 |
+
).input_ids
|
60 |
+
tokens = torch.tensor(tokens, dtype=torch.long, device=device)
|
61 |
+
context = clip(tokens)
|
62 |
+
to_idle(clip)
|
63 |
+
|
64 |
+
if sampler_name == "ddpm":
|
65 |
+
sampler = DDPMSampler(generator)
|
66 |
+
sampler.set_inference_timesteps(n_inference_steps)
|
67 |
+
else:
|
68 |
+
raise ValueError("Unknown sampler value %s. ")
|
69 |
+
|
70 |
+
latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
|
71 |
+
|
72 |
+
if input_image:
|
73 |
+
encoder = models["encoder"]
|
74 |
+
encoder.to(device)
|
75 |
+
|
76 |
+
input_image_tensor = input_image.resize((WIDTH, HEIGHT))
|
77 |
+
input_image_tensor = np.array(input_image_tensor)
|
78 |
+
input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
|
79 |
+
input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
|
80 |
+
input_image_tensor = input_image_tensor.unsqueeze(0)
|
81 |
+
input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
|
82 |
+
|
83 |
+
encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
|
84 |
+
latents = encoder(input_image_tensor, encoder_noise)
|
85 |
+
|
86 |
+
sampler.set_strength(strength=strength)
|
87 |
+
latents = sampler.add_noise(latents, sampler.timesteps[0])
|
88 |
+
|
89 |
+
to_idle(encoder)
|
90 |
+
else:
|
91 |
+
latents = torch.randn(latents_shape, generator=generator, device=device)
|
92 |
+
|
93 |
+
diffusion = models["diffusion"]
|
94 |
+
diffusion.to(device)
|
95 |
+
|
96 |
+
timesteps = tqdm(sampler.timesteps)
|
97 |
+
for i, timestep in enumerate(timesteps):
|
98 |
+
time_embedding = get_time_embedding(timestep).to(device)
|
99 |
+
|
100 |
+
model_input = latents
|
101 |
+
|
102 |
+
if do_cfg:
|
103 |
+
model_input = model_input.repeat(2, 1, 1, 1)
|
104 |
+
|
105 |
+
model_output = diffusion(model_input, context, time_embedding)
|
106 |
+
|
107 |
+
if do_cfg:
|
108 |
+
output_cond, output_uncond = model_output.chunk(2)
|
109 |
+
model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
|
110 |
+
|
111 |
+
latents = sampler.step(timestep, latents, model_output)
|
112 |
+
|
113 |
+
to_idle(diffusion)
|
114 |
+
|
115 |
+
decoder = models["decoder"]
|
116 |
+
decoder.to(device)
|
117 |
+
images = decoder(latents)
|
118 |
+
to_idle(decoder)
|
119 |
+
|
120 |
+
images = rescale(images, (-1, 1), (0, 255), clamp=True)
|
121 |
+
images = images.permute(0, 2, 3, 1)
|
122 |
+
images = images.to("cpu", torch.uint8).numpy()
|
123 |
+
return images[0]
|
124 |
+
|
125 |
+
def rescale(x, old_range, new_range, clamp=False):
|
126 |
+
old_min, old_max = old_range
|
127 |
+
new_min, new_max = new_range
|
128 |
+
x -= old_min
|
129 |
+
x *= (new_max - new_min) / (old_max - old_min)
|
130 |
+
x += new_min
|
131 |
+
if clamp:
|
132 |
+
x = x.clamp(new_min, new_max)
|
133 |
+
return x
|
134 |
+
|
135 |
+
def get_time_embedding(timestep):
|
136 |
+
# Shape: (160,)
|
137 |
+
freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
|
138 |
+
# Shape: (1, 160)
|
139 |
+
x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
|
140 |
+
# Shape: (1, 160 * 2)
|
141 |
+
return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
|
sd_gradio.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
from pathlib import Path
|
4 |
+
from transformers import CLIPTokenizer
|
5 |
+
import torch
|
6 |
+
from nanograd.models.stable_diffusion import model_loader, pipeline
|
7 |
+
|
8 |
+
DEVICE = "cpu"
|
9 |
+
ALLOW_CUDA = False
|
10 |
+
ALLOW_MPS = False
|
11 |
+
|
12 |
+
if torch.cuda.is_available() and ALLOW_CUDA:
|
13 |
+
DEVICE = "cuda"
|
14 |
+
elif torch.backends.mps.is_available() and ALLOW_MPS:
|
15 |
+
DEVICE = "mps"
|
16 |
+
print(f"Using device: {DEVICE}")
|
17 |
+
|
18 |
+
tokenizer_vocab_path = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\tokenizer_vocab.json")
|
19 |
+
tokenizer_merges_path = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\tokenizer_merges.txt")
|
20 |
+
model_file = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\v1-5-pruned-emaonly.ckpt")
|
21 |
+
|
22 |
+
tokenizer = CLIPTokenizer(str(tokenizer_vocab_path), merges_file=str(tokenizer_merges_path))
|
23 |
+
models = model_loader.preload_models_from_standard_weights(str(model_file), DEVICE)
|
24 |
+
|
25 |
+
def generate_image(prompt, cfg_scale, num_inference_steps, sampler):
|
26 |
+
uncond_prompt = ""
|
27 |
+
do_cfg = True
|
28 |
+
input_image = None
|
29 |
+
strength = 0.9
|
30 |
+
seed = 42
|
31 |
+
|
32 |
+
output_image = pipeline.generate(
|
33 |
+
prompt=prompt,
|
34 |
+
uncond_prompt=uncond_prompt,
|
35 |
+
input_image=input_image,
|
36 |
+
strength=strength,
|
37 |
+
do_cfg=do_cfg,
|
38 |
+
cfg_scale=cfg_scale,
|
39 |
+
sampler_name=sampler,
|
40 |
+
n_inference_steps=num_inference_steps,
|
41 |
+
seed=seed,
|
42 |
+
models=models,
|
43 |
+
device=DEVICE,
|
44 |
+
idle_device="cpu",
|
45 |
+
tokenizer=tokenizer,
|
46 |
+
)
|
47 |
+
|
48 |
+
output_image = Image.fromarray(output_image)
|
49 |
+
return output_image
|
50 |
+
|
51 |
+
# Gradio interface
|
52 |
+
def gradio_interface():
|
53 |
+
with gr.Blocks() as demo:
|
54 |
+
with gr.Row():
|
55 |
+
with gr.Column(scale=2):
|
56 |
+
prompt_input = gr.Textbox(label="Prompt", placeholder="A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution")
|
57 |
+
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
|
58 |
+
num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=100, value=20, step=5)
|
59 |
+
sampler = gr.Radio(label="Sampling Method", choices=["ddpm", "Euler a", "Euler", "LMS", "Heun", "DPM2 a", "PLMS"], value="ddpm")
|
60 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
61 |
+
with gr.Column(scale=2):
|
62 |
+
output_image = gr.Image(label="Output", show_label=False, height=512, width=512)
|
63 |
+
|
64 |
+
generate_btn.click(fn=generate_image, inputs=[prompt_input, cfg_scale, num_inference_steps, sampler], outputs=output_image)
|
65 |
+
|
66 |
+
demo.launch()
|
67 |
+
|
68 |
+
if __name__ == "__main__":
|
69 |
+
gradio_interface()
|
sd_inference.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from nanograd.models.stable_diffusion import model_loader
|
2 |
+
from nanograd.models.stable_diffusion import pipeline
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from pathlib import Path
|
6 |
+
from transformers import CLIPTokenizer
|
7 |
+
import torch
|
8 |
+
|
9 |
+
DEVICE = "cpu"
|
10 |
+
|
11 |
+
ALLOW_CUDA = False
|
12 |
+
ALLOW_MPS = False
|
13 |
+
|
14 |
+
if torch.cuda.is_available() and ALLOW_CUDA:
|
15 |
+
DEVICE = "cuda"
|
16 |
+
elif (torch.has_mps or torch.backends.mps.is_available()) and ALLOW_MPS:
|
17 |
+
DEVICE = "mps"
|
18 |
+
print(f"Using device: {DEVICE}")
|
19 |
+
|
20 |
+
tokenizer = CLIPTokenizer("nanograd\models\stable_diffusion\sd_data\\tokenizer_vocab.json", merges_file="nanograd\models\stable_diffusion\sd_data\\tokenizer_merges.txt")
|
21 |
+
model_file = "nanograd\models\stable_diffusion\sd_data\\v1-5-pruned-emaonly.ckpt"
|
22 |
+
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
|
23 |
+
|
24 |
+
## TEXT TO IMAGE
|
25 |
+
|
26 |
+
prompt = input("Enter your prompt: ")
|
27 |
+
# prompt = "A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution."
|
28 |
+
uncond_prompt = ""
|
29 |
+
do_cfg = True
|
30 |
+
cfg_scale = 8 # min: 1, max: 14
|
31 |
+
|
32 |
+
## IMAGE TO IMAGE
|
33 |
+
|
34 |
+
input_image = None
|
35 |
+
# Comment to disable image to image
|
36 |
+
# image_path = "../images/dog.jpg"
|
37 |
+
# input_image = Image.open(image_path)
|
38 |
+
# Higher values means more noise will be added to the input image, so the result will further from the input image.
|
39 |
+
strength = 0.9
|
40 |
+
|
41 |
+
## SAMPLER
|
42 |
+
|
43 |
+
sampler = "ddpm"
|
44 |
+
num_inference_steps = 50
|
45 |
+
seed = 42
|
46 |
+
|
47 |
+
|
48 |
+
def run():
|
49 |
+
output_image = pipeline.generate(
|
50 |
+
prompt=prompt,
|
51 |
+
uncond_prompt=uncond_prompt,
|
52 |
+
input_image=input_image,
|
53 |
+
strength=strength,
|
54 |
+
do_cfg=do_cfg,
|
55 |
+
cfg_scale=cfg_scale,
|
56 |
+
sampler_name=sampler,
|
57 |
+
n_inference_steps=num_inference_steps,
|
58 |
+
seed=seed,
|
59 |
+
models=models,
|
60 |
+
device=DEVICE,
|
61 |
+
idle_device="cpu",
|
62 |
+
tokenizer=tokenizer,
|
63 |
+
)
|
64 |
+
|
65 |
+
output_image = Image.fromarray(output_image)
|
66 |
+
output_path = "nanograd\models\stable_diffusion\output\\c.png"
|
67 |
+
output_image.save(output_path)
|
68 |
+
print(f"Image saved as {output_path}")
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
+
run()
|
72 |
+
|