import math

import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import transforms
from tqdm.auto import tqdm

img_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Block(nn.Module):
    """One U-Net stage: conv -> add time embedding -> conv -> down/upsample."""

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # Up blocks receive the skip connection concatenated on the
            # channel axis, hence 2 * in_ch input channels.
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.silu = nn.SiLU()
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First conv
        h = self.silu(self.bnorm1(self.conv1(x)))
        # Time embedding projected to out_ch channels
        time_emb = self.relu(self.time_mlp(t))
        # Extend the last 2 dimensions so it broadcasts over H and W
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add the time channel
        h = h + time_emb
        # Second conv
        h = self.silu(self.bnorm2(self.conv2(h)))
        # Down- or upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embeddings of the diffusion timestep."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        # Note: the sin/cos ordering differs from Vaswani et al., but is
        # equivalent up to a permutation of the embedding dimensions.
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class SimpleUnet(nn.Module):
    """A simplified variant of the U-Net architecture."""

    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (32, 64, 128, 256, 512)
        up_channels = (512, 256, 128, 64, 32)
        out_dim = 3
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsampling path
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])
        # Upsampling path
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # U-Net with skip connections
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add the residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)


def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


timesteps = 300
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)  # forward-process coefficient (unused at inference)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


def extract(a, t, x_shape):
    """Gather the schedule values at timestep t, reshaped to broadcast over x."""
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Equation 11 in the DDPM paper:
    # use the model (noise predictor) to predict the mean.
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        # No noise is added at the final step.
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2, line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise


# Algorithm 2, but save all intermediate images:
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device
    b = shape[0]
    # Start from pure noise (for each example in the batch).
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)),
                  desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img,
                       torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs


@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


model = SimpleUnet()
model.load_state_dict(torch.load("new_linear_model_1090.pt", map_location=device))
model.to(device)
model.eval()  # BatchNorm layers must use running statistics at inference

st.title("Generating images using a diffusion model")

if st.button("Click to generate image"):
    # Sample a single image; the final denoised result is samples[-1].
    samples = sample(model, image_size=img_size, batch_size=1, channels=3)
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),         # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: t.clamp(0, 1)),       # guard against out-of-range values
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW -> HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    img = reverse_transforms(torch.from_numpy(samples[-1][0]))
    st.image(img)
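
# --- Usage (a minimal sketch; it assumes this script is saved as app.py and
# that the checkpoint "new_linear_model_1090.pt" sits in the same directory) ---
# Streamlit apps are launched from the shell rather than with `python`:
#
#   streamlit run app.py
#
# Each button press runs the full 300-step reverse diffusion loop before the
# image appears, so on a CPU-only machine expect a noticeable wait.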