"""Gradio app that samples an image from a diffusion U-Net with saved weights.

The module builds the U-Net, loads its weights from
``Model_Saved_States/diffusion_64.pth``, wraps the reverse-diffusion sampling
loop in ``Diffusion`` and exposes it through a Gradio interface.
"""

import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (kept; no longer used by greet)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two bias-free 3x3 convolutions, each followed by single-group GroupNorm.

    A GELU sits between the two conv/norm pairs.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; ``out_channels``
            is used when this is ``None`` (or otherwise falsy).
        residual: When true, ``forward`` returns ``gelu(x + double_conv(x))``.
            The addition only works if ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        mid_channels = mid_channels or out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        """Apply the conv stack, optionally as a residual branch."""
        out = self.double_conv(x)
        if self.residual:
            return F.gelu(x + out)
        return out


class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two ``DoubleConv`` blocks, plus a
    per-channel offset computed from the time embedding.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
        emb_dim: Size of the last dimension of the time embedding.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        """Return the pooled/convolved features with the embedding added.

        ``t`` is projected to ``out_channels`` values per sample and added to
        every spatial position.
        """
        x = self.maxpool_conv(x)
        # (batch, C) -> (batch, C, 1, 1); broadcasting over H and W yields the
        # same values as materialising the repeated tensor.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb


class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, concatenation with a skip
    tensor, two ``DoubleConv`` blocks, plus a time-embedding offset.

    Args:
        in_channels: Channels after concatenating the skip tensor with the
            upsampled input.
        out_channels: Channels of the returned feature map.
        emb_dim: Size of the last dimension of the time embedding.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        """Upsample ``x``, concatenate ``skip_x`` in front of it along the
        channel dimension, convolve, and add the projected embedding."""
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        # Broadcast the per-channel offset over the spatial dimensions.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb


class UNet(nn.Module):
    """U-Net that maps an image and a timestep to a ``c_out``-channel output.

    Args:
        c_in: Channels of the input image.
        c_out: Channels of the output.
        time_dim: Size of the sinusoidal timestep encoding.
        device: Stored on the instance as ``self.device``.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 256)

        self.bot1 = DoubleConv(256, 512)
        # bot2 is not called in forward(); it stays registered so the module's
        # parameter names are unchanged for load_state_dict.
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.up2 = Up(256, 64)
        self.up3 = Up(128, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def positional_encoding(self, t, channels):
        """Return the sinusoidal encoding of ``t``.

        The sine half is concatenated before the cosine half along the last
        dimension. ``forward`` passes ``t`` with a trailing dimension of 1.
        """
        # Build the frequencies on t's own device so the product below cannot
        # mix devices.
        exponents = torch.arange(0, channels, 2, device=t.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, image, t):
        """Run the network.

        Args:
            image: Input batch with ``c_in`` channels. NOTE(review): the three
                pool/upsample pairs presumably need height and width to be
                multiples of 8 for the skip concatenations to line up.
            t: Timesteps, one per sample (``generate_samples`` passes a 1-D
                long tensor).

        Returns:
            The output of the final 1x1 convolution.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.positional_encoding(t, self.time_dim)

        x1 = self.inc(image)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)

        # Bottleneck: bot2 is intentionally skipped here, as in the original.
        x4 = self.bot1(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.up2(x, x2, t)
        x = self.up3(x, x1, t)
        return self.outc(x)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet(device=device).to(device)
# map_location lets a checkpoint saved from a CUDA run load on a CPU-only host.
model.load_state_dict(
    torch.load('Model_Saved_States/diffusion_64.pth', map_location=device)
)
model.eval()  # the model is only used for sampling in this script
img_size = 64


class Diffusion():
    """Linear-beta diffusion schedule with forward noising and sampling.

    Args:
        time_steps: Number of entries in the beta schedule.
        beta_start: First value of the schedule.
        beta_stop: Last value of the schedule.
        image_size: Height and width of generated samples.
        device: Device holding the schedule tensors and the samples.
    """

    def __init__(self, time_steps=500, beta_start=0.0001, beta_stop=0.02,
                 image_size=64, device=device):
        self.time_steps = time_steps
        self.beta_start = beta_start
        self.beta_stop = beta_stop
        self.img_size = image_size
        self.device = device

        self.beta = self.beta_schedule().to(device)
        self.alpha = (1 - self.beta).to(device)
        self.alpha_hat = torch.cumprod(self.alpha, dim=0).to(device)

    def beta_schedule(self):
        """Return ``time_steps`` evenly spaced betas from start to stop."""
        return torch.linspace(self.beta_start, self.beta_stop, self.time_steps)

    def noise_images(self, images, t):
        """Noise ``images`` to timesteps ``t``.

        Returns:
            ``(noised_images, noises)`` where
            ``noised_images = sqrt(alpha_hat[t]) * images
            + sqrt(1 - alpha_hat[t]) * noises`` and ``noises`` is standard
            normal noise shaped like ``images``.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noises = torch.randn_like(images)
        noised_images = sqrt_alpha_hat * images + sqrt_one_minus_alpha_hat * noises
        return noised_images, noises

    def random_timesteps(self, n):
        """Return ``n`` random integer timesteps in ``[1, time_steps)``."""
        return torch.randint(low=1, high=self.time_steps, size=(n,))

    def generate_samples(self, model, n):
        """Run the reverse process on ``n`` noise images and return the first.

        Args:
            model: Callable ``model(x, t)`` returning the predicted noise.
            n: Number of samples denoised in parallel.

        Returns:
            The first sample as a NumPy array with channels last, divided by
            255. ``greet`` multiplies by 255 again, so the two scalings cancel.
        """
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            # Walk i = time_steps - 1, ..., 1. The previous stop value of 1
            # ended the loop at i == 2, so the final step (which adds no
            # noise) never ran and the `else` branch below was unreachable.
            for i in range(self.time_steps - 1, 0, -1):
                t = torch.full((n,), i, dtype=torch.long, device=self.device)
                predicted_noise = model(x, t)

                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]

                # No fresh noise on the last step.
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)

                noise_scale = (1 - alpha) / torch.sqrt(1 - alpha_hat)
                x = (
                    1 / torch.sqrt(alpha) * (x - noise_scale * predicted_noise)
                    + torch.sqrt(beta) * noise
                )
        return x[0].cpu().numpy().transpose(1, 2, 0) / 255


diffusion = Diffusion()


def greet(n):
    """Gradio callback: sample one image and map it to the range [0, 1].

    Args:
        n: Value of the interface's number input. It is not used; exactly one
            sample is generated per call.

    Returns:
        The sample clipped to [-1, 1] and rescaled to [0, 1], channels last.
    """
    image = diffusion.generate_samples(model, n=1)
    # Undo the / 255 applied by generate_samples, then clip and rescale.
    # The earlier plt.imshow call was dropped: it added an artist to the
    # global pyplot figure on every request and nothing ever displayed it.
    image = (np.clip(image * 255, -1, 1) + 1) / 2
    return image


iface = gr.Interface(fn=greet, inputs="number", outputs="image")
iface.launch(share=True)