import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math
from PIL import Image
import streamlit as st

device = 'cpu'


def get_time_embedding(timestep):
    # Shape: (40,)
    freqs = torch.pow(10000, -torch.arange(start=0, end=40, dtype=torch.float32) / 80)
    # Shape: (1, 40)
    x = torch.tensor(timestep, dtype=torch.float32)[:, None] * freqs[None]
    # Shape: (1, 80)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)


class TimeEmbedding(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        # x: (1, 80)
        # (1, 80) -> (1, 320)
        x = self.linear_1(x)
        # (1, 320) -> (1, 320)
        x = F.silu(x)
        # (1, 320) -> (1, 320)
        x = self.linear_2(x)
        return x


class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=320):
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(16, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.linear_time = nn.Linear(n_time, out_channels)

        self.groupnorm_merged = nn.GroupNorm(16, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        # feature: (Batch_Size, In_Channels, Height, Width)
        # time: (1, 320)
        residue = feature

        # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = self.groupnorm_feature(feature)
        # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = F.silu(feature)
        # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        feature = self.conv_feature(feature)

        # (1, 320) -> (1, 320)
        time = F.silu(time)
        # (1, 320) -> (1, Out_Channels)
        time = self.linear_time(time)

        # Add width and height dimensions to time so it broadcasts over the feature map.
        # (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1) -> (Batch_Size, Out_Channels, Height, Width)
        merged = feature + time.unsqueeze(-1).unsqueeze(-1)
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = self.groupnorm_merged(merged)
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = F.silu(merged)
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = self.conv_merged(merged)

        # (Batch_Size, Out_Channels, Height, Width) + (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        return merged + self.residual_layer(residue)
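

# Illustrative only (not called by the app): a quick shape sanity check for the
# timestep-embedding pipeline. The sinusoidal embedding from get_time_embedding
# is 80-dimensional, TimeEmbedding(80) expands it to 320, and 320 is exactly the
# n_time default that UNET_ResidualBlock.linear_time expects. The helper name
# below is hypothetical and exists only for this sketch.
def _check_time_embedding_shapes():
    emb = get_time_embedding(torch.tensor([0]))
    assert emb.shape == (1, 80)
    projected = TimeEmbedding(80)(emb)
    assert projected.shape == (1, 320)
    # The projected embedding can be consumed directly by a residual block.
    block = UNET_ResidualBlock(64, 64)
    out = block(torch.randn(1, 64, 28, 28), projected)
    assert out.shape == (1, 64, 28, 28)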


class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # This combines the Wq, Wk and Wv matrices into one matrix
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        # This one represents the Wo matrix
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x):
        # x: (Batch_Size, Seq_Len, Dim)
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        # (Batch_Size, Seq_Len, H, Dim / H)
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensors of shape (Batch_Size, Seq_Len, Dim)
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = q @ k.transpose(-1, -2)
        # Divide by sqrt(d_k), where d_k = Dim / H.
        # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight /= math.sqrt(self.d_head)
        # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = F.softmax(weight, dim=-1)

        # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        output = weight @ v
        # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
        output = output.transpose(1, 2)
        # (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
        output = output.reshape(input_shape)
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
        output = self.out_proj(output)

        # (Batch_Size, Seq_Len, Dim)
        return output


class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int):
        super().__init__()
        channels = n_head * n_embd

        self.groupnorm = nn.GroupNorm(16, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x):
        # x: (Batch_Size, Features, Height, Width)
        residue_long = x

        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.groupnorm(x)
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.conv_input(x)

        n, c, h, w = x.shape
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
        x = x.view((n, c, h * w))
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
        x = x.transpose(-1, -2)

        # Normalization + self-attention with skip connection
        # (Batch_Size, Height * Width, Features)
        residue_short = x
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.layernorm_1(x)
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.attention_1(x)
        # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x += residue_short

        # Normalization + FFN with GeGLU and skip connection
        residue_short = x
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.layernorm_3(x)
        # Use of GeGLU taken from the Stable Diffusion GitHub repository
        # (Batch_Size, Height * Width, Features) -> two tensors of shape (Batch_Size, Height * Width, Features * 4)
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
        # Element-wise product: (Batch_Size, Height * Width, Features * 4) * (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features * 4)
        x = x * F.gelu(gate)
        # (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features)
        x = self.linear_geglu_2(x)
        # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x += residue_short

        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
        x = x.transpose(-1, -2)
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
        x = x.view((n, c, h, w))

        # Final skip connection between the initial input and the output of the block
        # (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        return self.conv_output(x) + residue_long
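

# Illustrative only (not called by the app): the attention block flattens the
# spatial grid into a sequence, attends over it, and restores the original
# layout, so input and output shapes match. The channel count must equal
# n_head * n_embd (e.g. 4 * 16 = 64). The helper name is hypothetical.
def _check_attention_block_shape():
    block = UNET_AttentionBlock(4, 16)
    x = torch.randn(1, 64, 28, 28)
    assert block(x).shape == x.shape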


class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)


class SwitchSequential(nn.Sequential):
    def forward(self, x, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x)
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            else:
                x = layer(x)
        return x


class UNET(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoders = nn.ModuleList([
            # (Batch_Size, 1, Height, Width) -> (Batch_Size, 64, Height, Width)
            SwitchSequential(nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)),

            SwitchSequential(UNET_ResidualBlock(64, 64), UNET_AttentionBlock(4, 16)),

            SwitchSequential(UNET_ResidualBlock(64, 64), UNET_AttentionBlock(4, 16)),

            # (Batch_Size, 64, Height, Width) -> (Batch_Size, 64, Height / 2, Width / 2)
            SwitchSequential(nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 64, Height / 2, Width / 2) -> (Batch_Size, 128, Height / 2, Width / 2)
            SwitchSequential(UNET_ResidualBlock(64, 128), UNET_AttentionBlock(4, 32)),

            # (Batch_Size, 128, Height / 2, Width / 2) -> (Batch_Size, 128, Height / 2, Width / 2)
            SwitchSequential(UNET_ResidualBlock(128, 128), UNET_AttentionBlock(4, 32)),

            # (Batch_Size, 128, Height / 2, Width / 2) -> (Batch_Size, 128, Height / 4, Width / 4)
            SwitchSequential(nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)),

            SwitchSequential(UNET_ResidualBlock(128, 128)),

            SwitchSequential(UNET_ResidualBlock(128, 128)),
        ])

        self.bottleneck = SwitchSequential(
            # (Batch_Size, 128, Height / 4, Width / 4) -> (Batch_Size, 128, Height / 4, Width / 4)
            UNET_ResidualBlock(128, 128),
            # (Batch_Size, 128, Height / 4, Width / 4) -> (Batch_Size, 128, Height / 4, Width / 4)
            UNET_AttentionBlock(4, 32),
            # (Batch_Size, 128, Height / 4, Width / 4) -> (Batch_Size, 128, Height / 4, Width / 4)
            UNET_ResidualBlock(128, 128),
        )

        self.decoders = nn.ModuleList([
            # Each decoder stage receives the previous output concatenated with the
            # matching encoder skip connection along the channel dimension.
            SwitchSequential(UNET_ResidualBlock(256, 128)),

            SwitchSequential(UNET_ResidualBlock(256, 128)),

            SwitchSequential(UNET_ResidualBlock(256, 128), Upsample(128)),

            SwitchSequential(UNET_ResidualBlock(256, 128), UNET_AttentionBlock(4, 32)),

            SwitchSequential(UNET_ResidualBlock(256, 128), UNET_AttentionBlock(4, 32)),

            SwitchSequential(UNET_ResidualBlock(192, 128), UNET_AttentionBlock(4, 32), Upsample(128)),

            SwitchSequential(UNET_ResidualBlock(192, 64), UNET_AttentionBlock(4, 16)),

            SwitchSequential(UNET_ResidualBlock(128, 64), UNET_AttentionBlock(4, 16)),

            SwitchSequential(UNET_ResidualBlock(128, 64), UNET_AttentionBlock(4, 16)),
        ])

    def forward(self, x, time):
        # x: (Batch_Size, 1, Height, Width)
        # time: (1, 320)
        skip_connections = []
        for layers in self.encoders:
            x = layers(x, time)
            skip_connections.append(x)

        x = self.bottleneck(x, time)

        for layers in self.decoders:
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, time)

        return x


class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(16, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # x: (Batch_Size, 64, Height, Width)
        # (Batch_Size, 64, Height, Width) -> (Batch_Size, 64, Height, Width)
        x = self.groupnorm(x)
        # (Batch_Size, 64, Height, Width) -> (Batch_Size, 64, Height, Width)
        x = F.silu(x)
        # (Batch_Size, 64, Height, Width) -> (Batch_Size, 1, Height, Width)
        x = self.conv(x)
        # (Batch_Size, 1, Height, Width)
        return x
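

# Illustrative only (not called by the app): for a 28x28 input the encoder path
# goes 28 -> 14 -> 7 and the two Upsample stages bring it back 7 -> 14 -> 28, so
# skip-connection shapes line up and the final feature map has 64 channels,
# which UNET_OutputLayer(64, 1) maps back to a single channel. The helper name
# is hypothetical.
def _check_unet_shapes():
    unet = UNET()
    x = torch.randn(1, 1, 28, 28)
    time = torch.randn(1, 320)
    features = unet(x, time)
    assert features.shape == (1, 64, 28, 28)
    assert UNET_OutputLayer(64, 1)(features).shape == (1, 1, 28, 28)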


class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(80)
        self.unet = UNET()
        self.final = UNET_OutputLayer(64, 1)

    def forward(self, x, time):
        # x: (Batch_Size, 1, Height, Width)
        # time: (1, 80)

        # (1, 80) -> (1, 320)
        time = self.time_embedding(time)
        # (Batch, 1, Height, Width) -> (Batch, 64, Height, Width)
        output = self.unet(x, time)
        # (Batch, 64, Height, Width) -> (Batch, 1, Height, Width)
        output = self.final(output)
        # (Batch, 1, Height, Width)
        return output


class DDPMSampler:
    def __init__(self, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # For the naming conventions, refer to the DDPM paper
        self.betas = torch.linspace(beta_start, beta_end, num_training_steps, dtype=torch.float32)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        generator = torch.Generator(device=device)
        generator.seed()
        self.generator = generator

        self.num_train_timesteps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # We always take the log of the variance, so clamp it to ensure it's not 0
        variance = torch.clamp(variance, min=1e-20)
        return variance

    def step(self, timestep: int, reshaped_image: torch.Tensor, model_output: torch.Tensor):
        t = timestep
        prev_t = self._get_previous_timestep(t)

        # 1. Compute alphas and betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. Compute the predicted original sample (also called predicted x_0) from the predicted noise
        pred_original_sample = (reshaped_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 3. Compute coefficients for pred_original_sample x_0 and current sample x_t
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 4. Compute the predicted previous sample mean µ_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * reshaped_image

        # 5. Add noise (only for t > 0; the final step is deterministic)
        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, device=device, dtype=model_output.dtype)
            variance = (self._get_variance(t) ** 0.5) * noise

        # A sample from N(mu, sigma^2) can be obtained as x = mu + sigma * N(0, 1);
        # the variable "variance" above already holds sigma * N(0, 1).
        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples, noise
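

# Illustrative only (not called by the app): a minimal single training step,
# assuming `x0` is an MNIST-like batch of shape (B, 1, 28, 28) scaled to [-1, 1]
# and everything lives on the CPU (the sampler's generator is created on the
# global `device`). add_noise implements the closed-form forward process
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, and the network
# is trained to predict eps with an MSE loss. The function and argument names
# are hypothetical; this is a sketch, not the training code used to produce
# `diffusion.tar`.
def _training_step_sketch(diffusion_model, sampler, optimizer, x0):
    # One random timestep per image in the batch.
    t = torch.randint(0, sampler.num_train_timesteps, (x0.shape[0],))
    # Corrupt x0; add_noise returns both the noisy batch and the noise it used.
    noisy, noise = sampler.add_noise(x0, t)
    # Predict the noise from the noisy image and the timestep embedding.
    pred_noise = diffusion_model(noisy, get_time_embedding(t))
    loss = F.mse_loss(pred_noise, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()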


def generate(
    diffusion_model,
    n_inference_steps=50,
):
    with torch.no_grad():
        sampler = DDPMSampler()
        sampler.set_inference_timesteps(n_inference_steps)

        image_shape = (1, 1, 28, 28)
        image = torch.randn(image_shape, device=device)
        diffusion_model.to(device)

        timesteps = sampler.timesteps
        for i, timestep in enumerate(timesteps):
            # (1, 80)
            time_embedding = get_time_embedding(torch.tensor(timestep).view(1,)).to(device)

            # (Batch_Size, 1, Height, Width)
            model_input = image

            # model_output is the predicted noise
            # (Batch_Size, 1, Height, Width) -> (Batch_Size, 1, Height, Width)
            model_output = diffusion_model(model_input, time_embedding)

            # (Batch_Size, 1, Height, Width) -> (Batch_Size, 1, Height, Width)
            image = sampler.step(timestep, image, model_output)

        # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
        image = image.permute(0, 2, 3, 1)
        image = image.detach().cpu()
        return image[0]


def gen_img():
    diffusion_model = Diffusion().to(device)
    checkpoint = torch.load("diffusion.tar", map_location=torch.device('cpu'))
    diffusion_model.load_state_dict(checkpoint)

    images = []
    for _ in range(12):
        image = generate(diffusion_model)
        image = image.squeeze()
        image = image.numpy()
        image = image.astype(np.float32)
        # Rescale to [0, 255] for display
        image = (image - image.min()) / (image.max() - image.min())
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image, mode='L')
        images.append(image)
    return images


st.title("Unconditional Image Generation")

if st.button('Generate Image'):
    with st.spinner('Generating an image...'):
        generated_images = gen_img()
        cols = st.columns(4)
        for i in range(3):
            for j in range(4):
                with cols[j]:
                    st.image(generated_images[i * 4 + j], caption=f"Generated Image {i * 4 + j + 1}")
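
# Usage note (assumptions: this file is saved as e.g. `app.py`, and a trained
# checkpoint named `diffusion.tar` sits in the same directory):
#   streamlit run app.py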