Esmail-AGumaan committed on
Commit 64e1ee8
Parent(s): 33e1004

Upload 13 files

Files changed (13)
  1. attention.py +77 -0
  2. clip.py +64 -0
  3. ddpm.py +112 -0
  4. decoder.py +100 -0
  5. demo.ipynb +0 -0
  6. diffusion.py +213 -0
  7. encoder.py +56 -0
  8. gaussing_diffusion.py +814 -0
  9. model_converter.py +0 -0
  10. model_loader.py +28 -0
  11. pipeline.py +141 -0
  12. sd_gradio.py +69 -0
  13. sd_inference.py +72 -0
attention.py ADDED
@@ -0,0 +1,77 @@
import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        weight = q @ k.transpose(-1, -2)

        if causal_mask:
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
            weight.masked_fill_(mask, -torch.inf)

        weight /= math.sqrt(self.d_head)

        weight = F.softmax(weight, dim=-1)
        output = weight @ v
        output = output.transpose(1, 2)

        output = output.reshape(input_shape)
        output = self.out_proj(output)
        return output

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        q = self.q_proj(x)
        k = self.k_proj(y)
        v = self.v_proj(y)

        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        weight = q @ k.transpose(-1, -2)

        weight /= math.sqrt(self.d_head)

        weight = F.softmax(weight, dim=-1)

        output = weight @ v

        output = output.transpose(1, 2).contiguous()

        output = output.view(input_shape)

        output = self.out_proj(output)

        return output
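A minimal shape check for these two blocks, as a sketch only: the import path and dimensions below are illustrative assumptions, not part of the commit.

import torch
from attention import SelfAttention, CrossAttention  # hypothetical local import path

self_attn = SelfAttention(n_heads=8, d_embed=320)
x = torch.randn(1, 64, 320)                   # (batch, sequence, d_embed)
print(self_attn(x, causal_mask=True).shape)   # torch.Size([1, 64, 320])

cross_attn = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
y = torch.randn(1, 77, 768)                   # e.g. a CLIP context: 77 tokens of width 768
print(cross_attn(x, y).shape)                 # torch.Size([1, 64, 320])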
clip.py ADDED
@@ -0,0 +1,64 @@
import torch
from torch import nn
from torch.nn import functional as F
from nanograd.models.stable_diffusion.attention import SelfAttention

class CLIPEmbedding(nn.Module):
    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))

    def forward(self, tokens):
        x = self.token_embedding(tokens)
        x += self.position_embedding

        return x

class CLIPLayer(nn.Module):
    def __init__(self, n_head: int, n_embd: int):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embd)
        self.attention = SelfAttention(n_head, n_embd)
        self.layernorm_2 = nn.LayerNorm(n_embd)
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        residue = x
        x = self.layernorm_1(x)
        x = self.attention(x, causal_mask=True)
        x += residue

        residue = x
        x = self.layernorm_2(x)
        x = self.linear_1(x)

        x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation
        x = self.linear_2(x)
        x += residue

        return x

class CLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = CLIPEmbedding(49408, 768, 77)

        self.layers = nn.ModuleList([
            CLIPLayer(12, 768) for i in range(12)
        ])

        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        tokens = tokens.type(torch.long)

        state = self.embedding(tokens)

        for layer in self.layers:
            state = layer(state)
        output = self.layernorm(state)

        return output
ddpm.py ADDED
@@ -0,0 +1,112 @@
import torch
import numpy as np

class DDPMSampler: # Denoising Diffusion Probabilistic Models Sampler

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator

        self.num_train_timesteps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        variance = torch.clamp(variance, min=1e-20)

        return variance

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image.
        More noise (strength ~ 1) means that the output will be further from the input image.
        Less noise (strength ~ 0) means that the output will be closer to the input image.
        """
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            variance = (self._get_variance(t) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
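A minimal sketch of how this sampler is driven during inference; the noise prediction here is a stand-in for the diffusion UNet's output, and the import path is an assumption.

import torch
from ddpm import DDPMSampler  # hypothetical local import path

generator = torch.Generator().manual_seed(42)
sampler = DDPMSampler(generator)
sampler.set_inference_timesteps(50)

latents = torch.randn(1, 4, 64, 64)           # random starting latent
for t in sampler.timesteps:
    noise_pred = torch.zeros_like(latents)    # stand-in for the model's predicted noise
    latents = sampler.step(int(t), latents, noise_pred)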
decoder.py ADDED
@@ -0,0 +1,100 @@
import torch
from torch import nn
from torch.nn import functional as F
from nanograd.models.stable_diffusion.attention import SelfAttention

class VAE_AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        residue = x
        x = self.groupnorm(x)
        n, c, h, w = x.shape
        x = x.view((n, c, h * w))
        x = x.transpose(-1, -2)
        x = self.attention(x)
        x = x.transpose(-1, -2)
        x = x.view((n, c, h, w))
        x += residue

        return x

class VAE_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        residue = x
        x = self.groupnorm_1(x)
        x = F.silu(x)
        x = self.conv_1(x)
        x = self.groupnorm_2(x)
        x = F.silu(x)
        x = self.conv_2(x)

        return x + self.residual_layer(residue)

class VAE_Decoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 4, kernel_size=1, padding=0),
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
            nn.Upsample(scale_factor=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
            nn.Upsample(scale_factor=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 256),
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),

            nn.Upsample(scale_factor=2),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),

            VAE_ResidualBlock(256, 128),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            nn.GroupNorm(32, 128),

            nn.SiLU(),

            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x /= 0.18215

        for module in self:
            x = module(x)
        return x
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
diffusion.py ADDED
@@ -0,0 +1,213 @@
import torch
from torch import nn
from torch.nn import functional as F
from nanograd.models.stable_diffusion.attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        x = self.linear_1(x)
        x = F.silu(x)
        x = self.linear_2(x)

        return x

class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.linear_time = nn.Linear(n_time, out_channels)

        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        residue = feature

        feature = self.groupnorm_feature(feature)
        feature = F.silu(feature)
        feature = self.conv_feature(feature)

        time = F.silu(time)

        time = self.linear_time(time)
        merged = feature + time.unsqueeze(-1).unsqueeze(-1)
        merged = self.groupnorm_merged(merged)
        merged = F.silu(merged)
        merged = self.conv_merged(merged)

        return merged + self.residual_layer(residue)

class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd

        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        residue_long = x

        x = self.groupnorm(x)
        x = self.conv_input(x)

        n, c, h, w = x.shape
        x = x.view((n, c, h * w))

        x = x.transpose(-1, -2)

        residue_short = x

        x = self.layernorm_1(x)
        x = self.attention_1(x)
        x += residue_short

        residue_short = x

        x = self.layernorm_2(x)
        x = self.attention_2(x, context)

        x += residue_short

        residue_short = x

        x = self.layernorm_3(x)

        # GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
        x = x * F.gelu(gate)
        x = self.linear_geglu_2(x)
        x += residue_short
        x = x.transpose(-1, -2)

        x = x.view((n, c, h, w))

        return self.conv_output(x) + residue_long

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)

class SwitchSequential(nn.Sequential):
    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x, context)
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            else:
                x = layer(x)
        return x

class UNET(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoders = nn.ModuleList([
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])

        self.bottleneck = SwitchSequential(
            UNET_ResidualBlock(1280, 1280),
            UNET_AttentionBlock(8, 160),
            UNET_ResidualBlock(1280, 1280),
        )

        self.decoders = nn.ModuleList([
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),

            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),

            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),

            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),

            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        skip_connections = []
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x)

        x = self.bottleneck(x, context, time)

        for layers in self.decoders:
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x


class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.groupnorm(x)
        x = F.silu(x)
        x = self.conv(x)

        return x

class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        time = self.time_embedding(time)

        output = self.unet(latent, context, time)

        output = self.final(output)

        return output
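A shape sketch for the full denoiser, under the dimensions used by the modules above (512x512 images, CLIP context of 77 tokens, 320-dim timestep embedding); the import path is an assumption.

import torch
from diffusion import Diffusion   # hypothetical local import path

model = Diffusion()
latent = torch.randn(1, 4, 64, 64)          # VAE latent for a 512x512 image
context = torch.randn(1, 77, 768)           # CLIP text embeddings
time = torch.randn(1, 320)                  # sinusoidal timestep embedding
print(model(latent, context, time).shape)   # torch.Size([1, 4, 64, 64])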
encoder.py ADDED
@@ -0,0 +1,56 @@
import torch
from torch import nn
from torch.nn import functional as F
from nanograd.models.stable_diffusion.decoder import VAE_AttentionBlock, VAE_ResidualBlock

class VAE_Encoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),

            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),

            VAE_ResidualBlock(128, 256),
            VAE_ResidualBlock(256, 256),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),

            VAE_ResidualBlock(256, 512),
            VAE_ResidualBlock(512, 512),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),

            nn.GroupNorm(32, 512),

            nn.SiLU(),

            nn.Conv2d(512, 8, kernel_size=3, padding=1),

            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        )

    def forward(self, x, noise):
        for module in self:

            if getattr(module, 'stride', None) == (2, 2):
                x = F.pad(x, (0, 1, 0, 1))

            x = module(x)
        mean, log_variance = torch.chunk(x, 2, dim=1)
        log_variance = torch.clamp(log_variance, -30, 20)
        variance = log_variance.exp()
        stdev = variance.sqrt()
        x = mean + stdev * noise
        # Constant taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L17C1-L17C1
        x *= 0.18215

        return x
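A sketch of the encode/decode round trip these two halves of the VAE implement; the tensors below are illustrative stand-ins, and the 0.18215 scaling is handled inside the modules themselves.

import torch
from encoder import VAE_Encoder   # hypothetical local import paths
from decoder import VAE_Decoder

encoder, decoder = VAE_Encoder(), VAE_Decoder()
image = torch.randn(1, 3, 512, 512)    # stand-in for a normalized RGB image
noise = torch.randn(1, 4, 64, 64)      # noise for the reparameterization trick
latents = encoder(image, noise)        # (1, 4, 64, 64), already scaled by 0.18215
reconstruction = decoder(latents)      # (1, 3, 512, 512), scaling undone in the decoder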
gaussing_diffusion.py ADDED
@@ -0,0 +1,814 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+ import enum
5
+
6
+ class GaussingDistribution:
7
+ def __init__(self, parameters: torch.Tensor) -> None:
8
+ self.mean, log_variance = torch.chunk(parameters, 2, dim=1)
9
+ self.log_variance = torch.clamp(log_variance, -30.0, 20.0)
10
+ self.std = torch.exp(0.5 * self.log_variance)
11
+
12
+ def sample(self):
13
+ return self.mean + self.std * torch.randn_like(self.std)
14
+
15
+ def normal_kl(mean1, logvar1, mean2, logvar2):
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, torch.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ logvar1, logvar2 = [
24
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
25
+ for x in (logvar1, logvar2)
26
+ ]
27
+
28
+ return 0.5 * (
29
+ -1.0
30
+ + logvar2
31
+ - logvar1
32
+ + torch.exp(logvar1 - logvar2)
33
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
34
+ )
35
+
36
+
37
+ def approx_standard_normal_cdf(x):
38
+
39
+ return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
40
+
41
+
42
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
43
+
44
+ centered_x = x - means
45
+ inv_stdv = torch.exp(-log_scales)
46
+ normalized_x = centered_x * inv_stdv
47
+ log_probs = torch.distributions.Normal(torch.zeros_like(x), torch.ones_like(x)).log_prob(normalized_x)
48
+ return log_probs
49
+
50
+
51
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
52
+
53
+ assert x.shape == means.shape == log_scales.shape
54
+ centered_x = x - means
55
+ inv_stdv = torch.exp(-log_scales)
56
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
57
+ cdf_plus = approx_standard_normal_cdf(plus_in)
58
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
59
+ cdf_min = approx_standard_normal_cdf(min_in)
60
+ log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
61
+ log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
62
+ cdf_delta = cdf_plus - cdf_min
63
+ log_probs = torch.where(
64
+ x < -0.999,
65
+ log_cdf_plus,
66
+ torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
67
+ )
68
+ assert log_probs.shape == x.shape
69
+ return log_probs
70
+
71
+ ################# Gaussing ####################
72
+
73
+ def mean_flat(tensor):
74
+
75
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
76
+
77
+
78
+ class ModelMeanType(enum.Enum):
79
+
80
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
81
+ START_X = enum.auto() # the model predicts x_0
82
+ EPSILON = enum.auto() # the model predicts epsilon
83
+
84
+
85
+ class ModelVarType(enum.Enum):
86
+
87
+ LEARNED = enum.auto()
88
+ FIXED_SMALL = enum.auto()
89
+ FIXED_LARGE = enum.auto()
90
+ LEARNED_RANGE = enum.auto()
91
+
92
+
93
+ class LossType(enum.Enum):
94
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
95
+ RESCALED_MSE = (
96
+ enum.auto()
97
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
98
+ KL = enum.auto() # use the variational lower-bound
99
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
100
+
101
+ def is_vb(self):
102
+ return self == LossType.KL or self == LossType.RESCALED_KL
103
+
104
+
105
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
106
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
107
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
108
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
109
+ return betas
110
+
111
+
112
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
113
+ if beta_schedule == "quad":
114
+ betas = (
115
+ np.linspace(
116
+ beta_start ** 0.5,
117
+ beta_end ** 0.5,
118
+ num_diffusion_timesteps,
119
+ dtype=np.float64,
120
+ )
121
+ ** 2
122
+ )
123
+ elif beta_schedule == "linear":
124
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
125
+ elif beta_schedule == "warmup10":
126
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
127
+ elif beta_schedule == "warmup50":
128
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
129
+ elif beta_schedule == "const":
130
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
131
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
132
+ betas = 1.0 / np.linspace(
133
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
134
+ )
135
+ else:
136
+ raise NotImplementedError(beta_schedule)
137
+ assert betas.shape == (num_diffusion_timesteps,)
138
+ return betas
139
+
140
+
141
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
142
+
143
+ if schedule_name == "linear":
144
+
145
+ scale = 1000 / num_diffusion_timesteps
146
+ return get_beta_schedule(
147
+ "linear",
148
+ beta_start=scale * 0.0001,
149
+ beta_end=scale * 0.02,
150
+ num_diffusion_timesteps=num_diffusion_timesteps,
151
+ )
152
+ elif schedule_name == "squaredcos_cap_v2":
153
+ return betas_for_alpha_bar(
154
+ num_diffusion_timesteps,
155
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
156
+ )
157
+ else:
158
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
159
+
160
+
161
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
162
+
163
+ betas = []
164
+ for i in range(num_diffusion_timesteps):
165
+ t1 = i / num_diffusion_timesteps
166
+ t2 = (i + 1) / num_diffusion_timesteps
167
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
168
+ return np.array(betas)
169
+
170
+
171
+ class GaussianDiffusion:
172
+ def __init__(
173
+ self,
174
+ *,
175
+ betas,
176
+ model_mean_type,
177
+ model_var_type,
178
+ loss_type
179
+ ):
180
+
181
+ self.model_mean_type = model_mean_type
182
+ self.model_var_type = model_var_type
183
+ self.loss_type = loss_type
184
+
185
+ # Use float64 for accuracy.
186
+ betas = np.array(betas, dtype=np.float64)
187
+ self.betas = betas
188
+ assert len(betas.shape) == 1, "betas must be 1-D"
189
+ assert (betas > 0).all() and (betas <= 1).all()
190
+
191
+ self.num_timesteps = int(betas.shape[0])
192
+
193
+ alphas = 1.0 - betas
194
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
195
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
196
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
197
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
198
+
199
+ # calculations for diffusion q(x_t | x_{t-1}) and others
200
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
201
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
202
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
203
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
204
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
205
+
206
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
207
+ self.posterior_variance = (
208
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
209
+ )
210
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
211
+ self.posterior_log_variance_clipped = np.log(
212
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
213
+ ) if len(self.posterior_variance) > 1 else np.array([])
214
+
215
+ self.posterior_mean_coef1 = (
216
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
217
+ )
218
+ self.posterior_mean_coef2 = (
219
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
220
+ )
221
+
222
+ def q_mean_variance(self, x_start, t):
223
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
224
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
225
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
226
+ return mean, variance, log_variance
227
+
228
+ def q_sample(self, x_start, t, noise=None):
229
+ if noise is None:
230
+ noise = torch.randn_like(x_start)
231
+ assert noise.shape == x_start.shape
232
+ return (
233
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
234
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
235
+ )
236
+
237
+ def q_posterior_mean_variance(self, x_start, x_t, t):
238
+ assert x_start.shape == x_t.shape
239
+ posterior_mean = (
240
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
241
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
242
+ )
243
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
244
+ posterior_log_variance_clipped = _extract_into_tensor(
245
+ self.posterior_log_variance_clipped, t, x_t.shape
246
+ )
247
+ assert (
248
+ posterior_mean.shape[0]
249
+ == posterior_variance.shape[0]
250
+ == posterior_log_variance_clipped.shape[0]
251
+ == x_start.shape[0]
252
+ )
253
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
254
+
255
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
256
+ if model_kwargs is None:
257
+ model_kwargs = {}
258
+
259
+ B, C = x.shape[:2]
260
+ assert t.shape == (B,)
261
+ model_output = model(x, t, **model_kwargs)
262
+ if isinstance(model_output, tuple):
263
+ model_output, extra = model_output
264
+ else:
265
+ extra = None
266
+
267
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
268
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
269
+ model_output, model_var_values = torch.split(model_output, C, dim=1)
270
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
271
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
272
+ frac = (model_var_values + 1) / 2
273
+ model_log_variance = frac * max_log + (1 - frac) * min_log
274
+ model_variance = torch.exp(model_log_variance)
275
+ else:
276
+ model_variance, model_log_variance = {
277
+ ModelVarType.FIXED_LARGE: (
278
+ np.append(self.posterior_variance[1], self.betas[1:]),
279
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
280
+ ),
281
+ ModelVarType.FIXED_SMALL: (
282
+ self.posterior_variance,
283
+ self.posterior_log_variance_clipped,
284
+ ),
285
+ }[self.model_var_type]
286
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
287
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
288
+
289
+ def process_xstart(x):
290
+ if denoised_fn is not None:
291
+ x = denoised_fn(x)
292
+ if clip_denoised:
293
+ return x.clamp(-1, 1)
294
+ return x
295
+
296
+ if self.model_mean_type == ModelMeanType.START_X:
297
+ pred_xstart = process_xstart(model_output)
298
+ else:
299
+ pred_xstart = process_xstart(
300
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
301
+ )
302
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
303
+
304
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
305
+ return {
306
+ "mean": model_mean,
307
+ "variance": model_variance,
308
+ "log_variance": model_log_variance,
309
+ "pred_xstart": pred_xstart,
310
+ "extra": extra,
311
+ }
312
+
313
+ def _predict_xstart_from_eps(self, x_t, t, eps):
314
+ assert x_t.shape == eps.shape
315
+ return (
316
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
317
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
318
+ )
319
+
320
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
321
+ return (
322
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
323
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
324
+
325
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
326
+ gradient = cond_fn(x, t, **model_kwargs)
327
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
328
+ return new_mean
329
+
330
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
331
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
332
+
333
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
334
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
335
+
336
+ out = p_mean_var.copy()
337
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
338
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
339
+ return out
340
+
341
+ def p_sample(
342
+ self,
343
+ model,
344
+ x,
345
+ t,
346
+ clip_denoised=True,
347
+ denoised_fn=None,
348
+ cond_fn=None,
349
+ model_kwargs=None,
350
+ ):
351
+ out = self.p_mean_variance(
352
+ model,
353
+ x,
354
+ t,
355
+ clip_denoised=clip_denoised,
356
+ denoised_fn=denoised_fn,
357
+ model_kwargs=model_kwargs,
358
+ )
359
+ noise = torch.randn_like(x)
360
+ nonzero_mask = (
361
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
362
+ ) # no noise when t == 0
363
+ if cond_fn is not None:
364
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
365
+ sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
366
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
367
+
368
+ def p_sample_loop(
369
+ self,
370
+ model,
371
+ shape,
372
+ noise=None,
373
+ clip_denoised=True,
374
+ denoised_fn=None,
375
+ cond_fn=None,
376
+ model_kwargs=None,
377
+ device=None,
378
+ progress=False,
379
+ ):
380
+ final = None
381
+ for sample in self.p_sample_loop_progressive(
382
+ model,
383
+ shape,
384
+ noise=noise,
385
+ clip_denoised=clip_denoised,
386
+ denoised_fn=denoised_fn,
387
+ cond_fn=cond_fn,
388
+ model_kwargs=model_kwargs,
389
+ device=device,
390
+ progress=progress,
391
+ ):
392
+ final = sample
393
+ return final["sample"]
394
+
395
+ def p_sample_loop_progressive(
396
+ self,
397
+ model,
398
+ shape,
399
+ noise=None,
400
+ clip_denoised=True,
401
+ denoised_fn=None,
402
+ cond_fn=None,
403
+ model_kwargs=None,
404
+ device=None,
405
+ progress=False,
406
+ ):
407
+ if device is None:
408
+ device = next(model.parameters()).device
409
+ assert isinstance(shape, (tuple, list))
410
+ if noise is not None:
411
+ img = noise
412
+ else:
413
+ img = torch.randn(*shape, device=device)
414
+ indices = list(range(self.num_timesteps))[::-1]
415
+
416
+ if progress:
417
+ # Lazy import so that we don't depend on tqdm.
418
+ from tqdm.auto import tqdm
419
+
420
+ indices = tqdm(indices)
421
+
422
+ for i in indices:
423
+ t = torch.tensor([i] * shape[0], device=device)
424
+ with torch.no_grad():
425
+ out = self.p_sample(
426
+ model,
427
+ img,
428
+ t,
429
+ clip_denoised=clip_denoised,
430
+ denoised_fn=denoised_fn,
431
+ cond_fn=cond_fn,
432
+ model_kwargs=model_kwargs,
433
+ )
434
+ yield out
435
+ img = out["sample"]
436
+
437
+ def ddim_sample(
438
+ self,
439
+ model,
440
+ x,
441
+ t,
442
+ clip_denoised=True,
443
+ denoised_fn=None,
444
+ cond_fn=None,
445
+ model_kwargs=None,
446
+ eta=0.0,
447
+ ):
448
+ out = self.p_mean_variance(
449
+ model,
450
+ x,
451
+ t,
452
+ clip_denoised=clip_denoised,
453
+ denoised_fn=denoised_fn,
454
+ model_kwargs=model_kwargs,
455
+ )
456
+ if cond_fn is not None:
457
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
458
+
459
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
460
+
461
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
462
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
463
+ sigma = (
464
+ eta
465
+ * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
466
+ * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
467
+ )
468
+ # Equation 12.
469
+ noise = torch.randn_like(x)
470
+ mean_pred = (
471
+ out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
472
+ + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
473
+ )
474
+ nonzero_mask = (
475
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
476
+ ) # no noise when t == 0
477
+ sample = mean_pred + nonzero_mask * sigma * noise
478
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
479
+
480
+ def ddim_reverse_sample(
481
+ self,
482
+ model,
483
+ x,
484
+ t,
485
+ clip_denoised=True,
486
+ denoised_fn=None,
487
+ cond_fn=None,
488
+ model_kwargs=None,
489
+ eta=0.0,
490
+ ):
491
+
492
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
493
+ out = self.p_mean_variance(
494
+ model,
495
+ x,
496
+ t,
497
+ clip_denoised=clip_denoised,
498
+ denoised_fn=denoised_fn,
499
+ model_kwargs=model_kwargs,
500
+ )
501
+ if cond_fn is not None:
502
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
503
+ eps = (
504
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
505
+ - out["pred_xstart"]
506
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
507
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
508
+
509
+ # Equation 12. reversed
510
+ mean_pred = out["pred_xstart"] * torch.sqrt(alpha_bar_next) + torch.sqrt(1 - alpha_bar_next) * eps
511
+
512
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
513
+
514
+ def ddim_sample_loop(
515
+ self,
516
+ model,
517
+ shape,
518
+ noise=None,
519
+ clip_denoised=True,
520
+ denoised_fn=None,
521
+ cond_fn=None,
522
+ model_kwargs=None,
523
+ device=None,
524
+ progress=False,
525
+ eta=0.0,
526
+ ):
527
+ final = None
528
+ for sample in self.ddim_sample_loop_progressive(
529
+ model,
530
+ shape,
531
+ noise=noise,
532
+ clip_denoised=clip_denoised,
533
+ denoised_fn=denoised_fn,
534
+ cond_fn=cond_fn,
535
+ model_kwargs=model_kwargs,
536
+ device=device,
537
+ progress=progress,
538
+ eta=eta,
539
+ ):
540
+ final = sample
541
+ return final["sample"]
542
+
543
+ def ddim_sample_loop_progressive(
544
+ self,
545
+ model,
546
+ shape,
547
+ noise=None,
548
+ clip_denoised=True,
549
+ denoised_fn=None,
550
+ cond_fn=None,
551
+ model_kwargs=None,
552
+ device=None,
553
+ progress=False,
554
+ eta=0.0,
555
+ ):
556
+ if device is None:
557
+ device = next(model.parameters()).device
558
+ assert isinstance(shape, (tuple, list))
559
+ if noise is not None:
560
+ img = noise
561
+ else:
562
+ img = torch.randn(*shape, device=device)
563
+ indices = list(range(self.num_timesteps))[::-1]
564
+
565
+ if progress:
566
+ # Lazy import so that we don't depend on tqdm.
567
+ from tqdm.auto import tqdm
568
+
569
+ indices = tqdm(indices)
570
+
571
+ for i in indices:
572
+ t = torch.tensor([i] * shape[0], device=device)
573
+ with torch.no_grad():
574
+ out = self.ddim_sample(
575
+ model,
576
+ img,
577
+ t,
578
+ clip_denoised=clip_denoised,
579
+ denoised_fn=denoised_fn,
580
+ cond_fn=cond_fn,
581
+ model_kwargs=model_kwargs,
582
+ eta=eta,
583
+ )
584
+ yield out
585
+ img = out["sample"]
586
+
587
+ def _vb_terms_bpd(
588
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
589
+ ):
590
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
591
+ x_start=x_start, x_t=x_t, t=t
592
+ )
593
+ out = self.p_mean_variance(
594
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
595
+ )
596
+ kl = normal_kl(
597
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
598
+ )
599
+ kl = mean_flat(kl) / np.log(2.0)
600
+
601
+ decoder_nll = -discretized_gaussian_log_likelihood(
602
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
603
+ )
604
+ assert decoder_nll.shape == x_start.shape
605
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
606
+
607
+ # At the first timestep return the decoder NLL,
608
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
609
+ output = torch.where((t == 0), decoder_nll, kl)
610
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
611
+
612
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
613
+
614
+ if model_kwargs is None:
615
+ model_kwargs = {}
616
+ if noise is None:
617
+ noise = torch.randn_like(x_start)
618
+ x_t = self.q_sample(x_start, t, noise=noise)
619
+
620
+ terms = {}
621
+
622
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
623
+ terms["loss"] = self._vb_terms_bpd(
624
+ model=model,
625
+ x_start=x_start,
626
+ x_t=x_t,
627
+ t=t,
628
+ clip_denoised=False,
629
+ model_kwargs=model_kwargs,
630
+ )["output"]
631
+ if self.loss_type == LossType.RESCALED_KL:
632
+ terms["loss"] *= self.num_timesteps
633
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
634
+ model_output = model(x_t, t, **model_kwargs)
635
+
636
+ if self.model_var_type in [
637
+ ModelVarType.LEARNED,
638
+ ModelVarType.LEARNED_RANGE,
639
+ ]:
640
+ B, C = x_t.shape[:2]
641
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
642
+ model_output, model_var_values = torch.split(model_output, C, dim=1)
643
+
644
+ frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
645
+ terms["vb"] = self._vb_terms_bpd(
646
+ model=lambda *args, r=frozen_out: r,
647
+ x_start=x_start,
648
+ x_t=x_t,
649
+ t=t,
650
+ clip_denoised=False,
651
+ )["output"]
652
+ if self.loss_type == LossType.RESCALED_MSE:
653
+
654
+ terms["vb"] *= self.num_timesteps / 1000.0
655
+
656
+ target = {
657
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
658
+ x_start=x_start, x_t=x_t, t=t
659
+ )[0],
660
+ ModelMeanType.START_X: x_start,
661
+ ModelMeanType.EPSILON: noise,
662
+ }[self.model_mean_type]
663
+ assert model_output.shape == target.shape == x_start.shape
664
+ terms["mse"] = mean_flat((target - model_output) ** 2)
665
+ if "vb" in terms:
666
+ terms["loss"] = terms["mse"] + terms["vb"]
667
+ else:
668
+ terms["loss"] = terms["mse"]
669
+ else:
670
+ raise NotImplementedError(self.loss_type)
671
+
672
+ return terms
673
+
674
+ def _prior_bpd(self, x_start):
675
+ batch_size = x_start.shape[0]
676
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
677
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
678
+ kl_prior = normal_kl(
679
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
680
+ )
681
+ return mean_flat(kl_prior) / np.log(2.0)
682
+
683
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
684
+
685
+ device = x_start.device
686
+ batch_size = x_start.shape[0]
687
+
688
+ vb = []
689
+ xstart_mse = []
690
+ mse = []
691
+ for t in list(range(self.num_timesteps))[::-1]:
692
+ t_batch = torch.tensor([t] * batch_size, device=device)
693
+ noise = torch.randn_like(x_start)
694
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
695
+ # Calculate VLB term at the current timestep
696
+ with torch.no_grad():
697
+ out = self._vb_terms_bpd(
698
+ model,
699
+ x_start=x_start,
700
+ x_t=x_t,
701
+ t=t_batch,
702
+ clip_denoised=clip_denoised,
703
+ model_kwargs=model_kwargs,
704
+ )
705
+ vb.append(out["output"])
706
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
707
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
708
+ mse.append(mean_flat((eps - noise) ** 2))
709
+
710
+ vb = torch.stack(vb, dim=1)
711
+ xstart_mse = torch.stack(xstart_mse, dim=1)
712
+ mse = torch.stack(mse, dim=1)
713
+
714
+ prior_bpd = self._prior_bpd(x_start)
715
+ total_bpd = vb.sum(dim=1) + prior_bpd
716
+ return {
717
+ "total_bpd": total_bpd,
718
+ "prior_bpd": prior_bpd,
719
+ "vb": vb,
720
+ "xstart_mse": xstart_mse,
721
+ "mse": mse,
722
+ }
723
+
724
+
725
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
726
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
727
+ while len(res.shape) < len(broadcast_shape):
728
+ res = res[..., None]
729
+ return res + torch.zeros(broadcast_shape, device=timesteps.device)
730
+
731
+ ############################### Denoising Diffusion Probabilistic Model###################################
732
+ class DDPMSampler:
733
+ def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
734
+ self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
735
+ self.alphas = 1.0 - self.betas
736
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
737
+ self.one = torch.tensor(1.0)
738
+ self.generator = generator
739
+ self.num_train_timesteps = num_training_steps
740
+ self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
741
+
742
+ def set_inference_timesteps(self, num_inference_steps=50):
743
+ self.num_inference_steps = num_inference_steps
744
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
745
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
746
+ self.timesteps = torch.from_numpy(timesteps)
747
+
748
+ def _get_previous_timestep(self, timestep: int) -> int:
749
+ prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
750
+ return prev_t
751
+
752
+ def _get_variance(self, timestep: int) -> torch.Tensor:
753
+ prev_t = self._get_previous_timestep(timestep)
754
+ alpha_prod_t = self.alphas_cumprod[timestep]
755
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
756
+ current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
757
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
758
+ variance = torch.clamp(variance, min=1e-20)
759
+ return variance
760
+
761
+ def set_strength(self, strength=1):
762
+ """
763
+ Set how much noise to add to the input image.
764
+ More noise (strength ~ 1) means that the output will be further from the input image.
765
+ Less noise (strength ~ 0) means that the output will be closer to the input image.
766
+ """
767
+
768
+ start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
769
+ self.timesteps = self.timesteps[start_step:]
770
+ self.start_step = start_step
771
+
772
+ def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
773
+ t = timestep
774
+ prev_t = self._get_previous_timestep(t)
775
+ alpha_prod_t = self.alphas_cumprod[t]
776
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
777
+ beta_prod_t = 1 - alpha_prod_t
778
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
779
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
780
+ current_beta_t = 1 - current_alpha_t
781
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
782
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
783
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
784
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
785
+ variance = 0
786
+ if t > 0:
787
+ device = model_output.device
788
+ noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
789
+
790
+ variance = (self._get_variance(t) ** 0.5) * noise
791
+ pred_prev_sample = pred_prev_sample + variance
792
+
793
+ return pred_prev_sample
794
+
795
+ def add_noise(
796
+ self,
797
+ original_samples: torch.FloatTensor,
798
+ timesteps: torch.IntTensor,
799
+ ) -> torch.FloatTensor:
800
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
801
+ timesteps = timesteps.to(original_samples.device)
802
+
803
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
804
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
805
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
806
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
807
+
808
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
809
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
810
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
811
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
812
+ noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
813
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
814
+ return noisy_samples
model_converter.py ADDED
The diff for this file is too large to render. See raw diff
 
model_loader.py ADDED
@@ -0,0 +1,28 @@
from nanograd.models.stable_diffusion.clip import CLIP
from nanograd.models.stable_diffusion.encoder import VAE_Encoder
from nanograd.models.stable_diffusion.decoder import VAE_Decoder
from nanograd.models.stable_diffusion.diffusion import Diffusion

from nanograd.models.stable_diffusion import model_converter

def preload_models_from_standard_weights(ckpt_path, device):
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    diffusion = Diffusion().to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
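A sketch of loading a converted checkpoint with this helper; the checkpoint filename below is a placeholder, and model_converter.load_from_standard_weights is assumed to return the per-module state dicts used above.

import torch
from nanograd.models.stable_diffusion import model_loader

device = "cuda" if torch.cuda.is_available() else "cpu"
# "v1-5-pruned-emaonly.ckpt" is a placeholder path to a Stable Diffusion v1 checkpoint.
models = model_loader.preload_models_from_standard_weights("v1-5-pruned-emaonly.ckpt", device)
print(models.keys())  # dict_keys(['clip', 'encoder', 'decoder', 'diffusion'])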
pipeline.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from nanograd.models.stable_diffusion.ddpm import DDPMSampler
+
+ WIDTH = 512
+ HEIGHT = 512
+ LATENTS_WIDTH = WIDTH // 8
+ LATENTS_HEIGHT = HEIGHT // 8
+
+ def generate(
+     prompt,
+     uncond_prompt=None,
+     input_image=None,
+     strength=0.8,
+     do_cfg=True,
+     cfg_scale=7.5,
+     sampler_name="ddpm",
+     n_inference_steps=50,
+     models={},
+     seed=None,
+     device=None,
+     idle_device=None,
+     tokenizer=None,
+ ):
+     with torch.no_grad():
+         if not 0 < strength <= 1:
+             raise ValueError("strength must be between 0 and 1")
+
+         if idle_device:
+             to_idle = lambda x: x.to(idle_device)
+         else:
+             to_idle = lambda x: x
+
+         generator = torch.Generator(device=device)
+         if seed is None:
+             generator.seed()
+         else:
+             generator.manual_seed(seed)
+
+         clip = models["clip"]
+         clip.to(device)
+
+         if do_cfg:
+             cond_tokens = tokenizer.batch_encode_plus(
+                 [prompt], padding="max_length", max_length=77
+             ).input_ids
+             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
+             cond_context = clip(cond_tokens)
+             uncond_tokens = tokenizer.batch_encode_plus(
+                 [uncond_prompt], padding="max_length", max_length=77
+             ).input_ids
+             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
+             uncond_context = clip(uncond_tokens)
+             context = torch.cat([cond_context, uncond_context])
+         else:
+             tokens = tokenizer.batch_encode_plus(
+                 [prompt], padding="max_length", max_length=77
+             ).input_ids
+             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
+             context = clip(tokens)
+         to_idle(clip)
+
+         if sampler_name == "ddpm":
+             sampler = DDPMSampler(generator)
+             sampler.set_inference_timesteps(n_inference_steps)
+         else:
+             raise ValueError(f"Unknown sampler value {sampler_name}.")
+
+         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
+
+         if input_image:
+             encoder = models["encoder"]
+             encoder.to(device)
+
+             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
+             input_image_tensor = np.array(input_image_tensor)
+             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
+             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
+             input_image_tensor = input_image_tensor.unsqueeze(0)
+             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
+
+             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
+             latents = encoder(input_image_tensor, encoder_noise)
+
+             sampler.set_strength(strength=strength)
+             latents = sampler.add_noise(latents, sampler.timesteps[0])
+
+             to_idle(encoder)
+         else:
+             latents = torch.randn(latents_shape, generator=generator, device=device)
+
+         diffusion = models["diffusion"]
+         diffusion.to(device)
+
+         timesteps = tqdm(sampler.timesteps)
+         for i, timestep in enumerate(timesteps):
+             time_embedding = get_time_embedding(timestep).to(device)
+
+             model_input = latents
+
+             if do_cfg:
+                 model_input = model_input.repeat(2, 1, 1, 1)
+
+             model_output = diffusion(model_input, context, time_embedding)
+
+             if do_cfg:
+                 output_cond, output_uncond = model_output.chunk(2)
+                 model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
+
+             latents = sampler.step(timestep, latents, model_output)
+
+         to_idle(diffusion)
+
+         decoder = models["decoder"]
+         decoder.to(device)
+         images = decoder(latents)
+         to_idle(decoder)
+
+         images = rescale(images, (-1, 1), (0, 255), clamp=True)
+         images = images.permute(0, 2, 3, 1)
+         images = images.to("cpu", torch.uint8).numpy()
+         return images[0]
+
+ def rescale(x, old_range, new_range, clamp=False):
+     old_min, old_max = old_range
+     new_min, new_max = new_range
+     x -= old_min
+     x *= (new_max - new_min) / (old_max - old_min)
+     x += new_min
+     if clamp:
+         x = x.clamp(new_min, new_max)
+     return x
+
+ def get_time_embedding(timestep):
+     # Shape: (160,)
+     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+     # Shape: (1, 160)
+     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
+     # Shape: (1, 160 * 2)
+     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
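get_time_embedding is the usual sinusoidal timestep embedding: 160 frequencies omega_i = 10000^(-i/160), with the cosine and sine halves concatenated, so the diffusion model receives a (1, 320) vector per denoising step. A quick shape check against the function above:

    emb = get_time_embedding(999)
    print(emb.shape)  # torch.Size([1, 320])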
sd_gradio.py ADDED
@@ -0,0 +1,69 @@
+ import gradio as gr
+ from PIL import Image
+ from pathlib import Path
+ from transformers import CLIPTokenizer
+ import torch
+ from nanograd.models.stable_diffusion import model_loader, pipeline
+
+ DEVICE = "cpu"
+ ALLOW_CUDA = False
+ ALLOW_MPS = False
+
+ if torch.cuda.is_available() and ALLOW_CUDA:
+     DEVICE = "cuda"
+ elif torch.backends.mps.is_available() and ALLOW_MPS:
+     DEVICE = "mps"
+ print(f"Using device: {DEVICE}")
+
+ tokenizer_vocab_path = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\tokenizer_vocab.json")
+ tokenizer_merges_path = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\tokenizer_merges.txt")
+ model_file = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\v1-5-pruned-emaonly.ckpt")
+
+ tokenizer = CLIPTokenizer(str(tokenizer_vocab_path), merges_file=str(tokenizer_merges_path))
+ models = model_loader.preload_models_from_standard_weights(str(model_file), DEVICE)
+
+ def generate_image(prompt, cfg_scale, num_inference_steps, sampler):
+     uncond_prompt = ""
+     do_cfg = True
+     input_image = None
+     strength = 0.9
+     seed = 42
+
+     output_image = pipeline.generate(
+         prompt=prompt,
+         uncond_prompt=uncond_prompt,
+         input_image=input_image,
+         strength=strength,
+         do_cfg=do_cfg,
+         cfg_scale=cfg_scale,
+         sampler_name=sampler,
+         n_inference_steps=num_inference_steps,
+         seed=seed,
+         models=models,
+         device=DEVICE,
+         idle_device="cpu",
+         tokenizer=tokenizer,
+     )
+
+     output_image = Image.fromarray(output_image)
+     return output_image
+
+ # Gradio interface
+ def gradio_interface():
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column(scale=2):
+                 prompt_input = gr.Textbox(label="Prompt", placeholder="A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution")
+                 cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
+                 num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=100, value=20, step=5)
+                 sampler = gr.Radio(label="Sampling Method", choices=["ddpm", "Euler a", "Euler", "LMS", "Heun", "DPM2 a", "PLMS"], value="ddpm")  # note: only "ddpm" is implemented in pipeline.generate
+                 generate_btn = gr.Button("Generate", variant="primary")
+             with gr.Column(scale=2):
+                 output_image = gr.Image(label="Output", show_label=False, height=512, width=512)
+
+         generate_btn.click(fn=generate_image, inputs=[prompt_input, cfg_scale, num_inference_steps, sampler], outputs=output_image)
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     gradio_interface()
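The script launches the UI locally by default. If the demo needs to be reachable from another machine, Gradio's standard launch options can be passed instead (a sketch, not part of this file):

    # inside gradio_interface(), replacing the plain launch call
    demo.launch(server_name="0.0.0.0", share=True)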
sd_inference.py ADDED
@@ -0,0 +1,72 @@
+ from nanograd.models.stable_diffusion import model_loader
+ from nanograd.models.stable_diffusion import pipeline
+
+ from PIL import Image
+ from pathlib import Path
+ from transformers import CLIPTokenizer
+ import torch
+
+ DEVICE = "cpu"
+
+ ALLOW_CUDA = False
+ ALLOW_MPS = False
+
+ if torch.cuda.is_available() and ALLOW_CUDA:
+     DEVICE = "cuda"
+ elif torch.backends.mps.is_available() and ALLOW_MPS:
+     DEVICE = "mps"
+ print(f"Using device: {DEVICE}")
+
+ tokenizer = CLIPTokenizer(r"nanograd\models\stable_diffusion\sd_data\tokenizer_vocab.json", merges_file=r"nanograd\models\stable_diffusion\sd_data\tokenizer_merges.txt")
+ model_file = r"nanograd\models\stable_diffusion\sd_data\v1-5-pruned-emaonly.ckpt"
+ models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
+
+ ## TEXT TO IMAGE
+
+ prompt = input("Enter your prompt: ")
+ # prompt = "A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution."
+ uncond_prompt = ""
+ do_cfg = True
+ cfg_scale = 8  # min: 1, max: 14
+
+ ## IMAGE TO IMAGE
+
+ input_image = None
+ # Uncomment to enable image-to-image
+ # image_path = "../images/dog.jpg"
+ # input_image = Image.open(image_path)
+ # Higher values mean more noise is added to the input image, so the result will be further from the input image.
+ strength = 0.9
+
+ ## SAMPLER
+
+ sampler = "ddpm"
+ num_inference_steps = 50
+ seed = 42
+
+
+ def run():
+     output_image = pipeline.generate(
+         prompt=prompt,
+         uncond_prompt=uncond_prompt,
+         input_image=input_image,
+         strength=strength,
+         do_cfg=do_cfg,
+         cfg_scale=cfg_scale,
+         sampler_name=sampler,
+         n_inference_steps=num_inference_steps,
+         seed=seed,
+         models=models,
+         device=DEVICE,
+         idle_device="cpu",
+         tokenizer=tokenizer,
+     )
+
+     output_image = Image.fromarray(output_image)
+     output_path = r"nanograd\models\stable_diffusion\output\c.png"
+     output_image.save(output_path)
+     print(f"Image saved as {output_path}")
+
+ if __name__ == "__main__":
+     run()
+