# Source: Hugging Face Space (status when captured: Sleeping).
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
class DoubleConv(nn.Module):
    """Conv3x3 -> GroupNorm -> GELU -> Conv3x3 -> GroupNorm, optionally residual.

    Both convolutions use ``padding=1`` and no bias, so the spatial size of
    the input is preserved. GroupNorm is used with a single group.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when falsy (``None`` or ``0``).
        residual: When true, ``forward`` returns ``gelu(x + block(x))``, which
            requires ``in_channels == out_channels`` for the addition to work.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels if mid_channels else out_channels
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        # Unpacking into Sequential keeps the sub-module indices (0..4) and
        # therefore the state_dict keys.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``; add the skip path first when residual."""
        out = self.double_conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)
class Down(nn.Module):
    """Encoder stage: 2x max-pool, two DoubleConv blocks, plus a time-embedding bias.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
        emb_dim: Size of the last dimension of the time embedding ``t``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)
        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample ``x`` and add the projected embedding ``t`` to every pixel.

        ``t`` must be 2-D with ``emb_dim`` features in its last dimension; the
        projection is tiled over the spatial dimensions of the pooled features.
        """
        features = self.maxpool_conv(x)
        height, width = features.shape[-2], features.shape[-1]
        channel_bias = self.emb_layer(t)[:, :, None, None]
        return features + channel_bias.repeat(1, 1, height, width)
class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, concatenate the skip, two DoubleConvs.

    Args:
        in_channels: Channel count of the tensor *after* the skip connection
            has been concatenated with the upsampled input.
        out_channels: Channels of the returned feature map.
        emb_dim: Size of the last dimension of the time embedding ``t``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        refine = DoubleConv(in_channels, in_channels, residual=True)
        # Third positional argument is mid_channels: squeeze through half the width.
        reduce = DoubleConv(in_channels, out_channels, in_channels // 2)
        self.conv = nn.Sequential(refine, reduce)
        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        """Upsample ``x``, fuse it with ``skip_x`` and add the projected ``t``.

        ``skip_x`` is placed first along the channel dimension; its spatial
        size must equal that of the upsampled ``x``.
        """
        upsampled = self.up(x)
        fused = self.conv(torch.cat([skip_x, upsampled], dim=1))
        height, width = fused.shape[-2], fused.shape[-1]
        channel_bias = self.emb_layer(t)[:, :, None, None]
        return fused + channel_bias.repeat(1, 1, height, width)
class UNet(nn.Module):
    """Time-conditioned U-Net that maps a noisy image and a timestep to an image-shaped output.

    The network has three ``Down`` stages (each halves the resolution), a
    convolutional bottleneck, and three ``Up`` stages fed by skip connections,
    so input height and width should be divisible by 8 for the skip
    concatenations to line up.

    Args:
        c_in: Channels of the input image.
        c_out: Channels of the output.
        time_dim: Width of the sinusoidal timestep embedding; forwarded to
            every ``Down``/``Up`` stage so the embedding layers always match.
        device: Stored on ``self.device`` for callers that read it. The
            forward pass no longer depends on it.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128, emb_dim=time_dim)
        self.down2 = Down(128, 256, emb_dim=time_dim)
        self.down3 = Down(256, 256, emb_dim=time_dim)

        self.bot1 = DoubleConv(256, 512)
        # bot2 is skipped in forward() but must stay registered: removing it
        # would change the state_dict keys expected by saved checkpoints.
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Up's in_channels counts the concatenated skip connection.
        self.up1 = Up(512, 128, emb_dim=time_dim)
        self.up2 = Up(256, 64, emb_dim=time_dim)
        self.up3 = Up(128, 64, emb_dim=time_dim)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def positional_encoding(self, t, channels):
        """Return the sinusoidal embedding of ``t``.

        Args:
            t: Float tensor of shape ``(batch, 1)``.
            channels: Embedding width; the first half holds sines, the second
                half cosines.

        Returns:
            Tensor of shape ``(batch, channels)`` on the same device as ``t``.
        """
        # Build the frequencies on t's device rather than self.device so the
        # encoding keeps working after the module is moved with .to(...).
        exponents = torch.arange(0, channels, 2, device=t.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, image, t):
        """Run the network.

        Args:
            image: Tensor of shape ``(batch, c_in, H, W)``.
            t: 1-D tensor with one timestep per batch element.

        Returns:
            Tensor of shape ``(batch, c_out, H, W)``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.positional_encoding(t, self.time_dim)

        x1 = self.inc(image)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)

        # NOTE(review): bot2 is intentionally bypassed here, as in the original
        # code; presumably the checkpoint was trained this way - confirm before
        # re-enabling it.
        x4 = self.bot1(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.up2(x, x2, t)
        x = self.up3(x, x1, t)
        return self.outc(x)
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet(device=device).to(device)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a CUDA process still loads on a CPU-only host.
model.load_state_dict(
    torch.load('Model_Saved_States/diffusion_64.pth', map_location=device)
)
img_size = 64
class Diffusion():
    """Linear-beta diffusion process: forward noising and ancestral sampling.

    Args:
        time_steps: Number of diffusion steps (length of the beta schedule).
        beta_start: First value of the linear beta schedule.
        beta_stop: Last value of the linear beta schedule.
        image_size: Height and width of the square images that are sampled.
        device: Device for the schedule tensors and for sampling. ``None``
            selects ``'cuda'`` when available, else ``'cpu'``.
    """

    def __init__(self, time_steps=500, beta_start=0.0001, beta_stop=0.02,
                 image_size=64, device=None):
        # Resolve the default at call time instead of binding a module-level
        # global at class-definition time.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.time_steps = time_steps
        self.beta_start = beta_start
        self.beta_stop = beta_stop
        self.img_size = image_size
        self.device = device

        self.beta = self.beta_schedule().to(device)
        self.alpha = 1 - self.beta
        # alpha_hat[t] is the cumulative product of alpha[0..t].
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def beta_schedule(self):
        """Return ``time_steps`` evenly spaced betas from ``beta_start`` to ``beta_stop``."""
        return torch.linspace(self.beta_start, self.beta_stop, self.time_steps)

    def noise_images(self, images, t):
        """Noise ``images`` to the timesteps ``t``.

        Args:
            images: Tensor of shape ``(batch, C, H, W)``.
            t: 1-D integer tensor with one timestep index per image.

        Returns:
            ``(noised_images, noises)`` where ``noises`` is the Gaussian noise
            that was mixed in.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noises = torch.randn_like(images)
        noised_images = sqrt_alpha_hat * images + sqrt_one_minus_alpha_hat * noises
        return noised_images, noises

    def random_timesteps(self, n):
        """Return ``n`` timesteps drawn uniformly from ``[1, time_steps)``."""
        return torch.randint(low=1, high=self.time_steps, size=(n,))

    def generate_samples(self, model, n):
        """Sample ``n`` images with ``model`` and return the first one.

        Args:
            model: Callable ``model(x, t)`` returning the predicted noise for
                the batch ``x`` at integer timesteps ``t``.
            n: Batch size to sample (must be at least 1).

        Returns:
            NumPy array of shape ``(img_size, img_size, 3)`` holding the first
            sample divided by 255; callers are expected to undo that scaling.
        """
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size), device=self.device)
            # Walk i = time_steps-1 ... 1. The previous stop value of 1 ended
            # the loop at i == 2, so the final noise-free step never ran.
            for i in range(self.time_steps - 1, 0, -1):
                t = torch.full((n,), i, dtype=torch.long, device=self.device)
                predicted_noise = model(x, t)
                # Every batch element shares the same timestep, so scalar
                # coefficients broadcast over the whole batch.
                alpha = self.alpha[i]
                alpha_hat = self.alpha_hat[i]
                beta = self.beta[i]
                # No fresh noise is injected on the last step.
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                x = (
                    1 / torch.sqrt(alpha)
                    * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise)
                    + torch.sqrt(beta) * noise
                )
        return x[0].cpu().numpy().transpose(1, 2, 0) / 255
# Module-level sampler built with the default schedule; the Gradio callback
# below reuses this single instance for every request.
diffusion = Diffusion()
import numpy as np | |
def greet(n):
    """Gradio callback: sample one image and return it scaled to ``[0, 1]``.

    Args:
        n: Value of the "number" input. It is currently ignored; exactly one
            image is sampled per call.

    Returns:
        Array of shape ``(H, W, 3)`` with values in ``[0, 1]``.
    """
    image = diffusion.generate_samples(model, n=1)
    # generate_samples divides by 255, so multiply back before clipping the
    # model output to [-1, 1] and shifting it into [0, 1].
    # The former plt.imshow(image) call was dropped: it only added an artist
    # to the global pyplot figure on every request and was never displayed.
    return (np.clip(image * 255, -1, 1) + 1) / 2
# Web UI: a numeric input wired to greet(), whose return value is rendered as
# an image. share=True asks Gradio for a public share link.
iface = gr.Interface(
    fn=greet,
    inputs="number",
    outputs="image",
)
iface.launch(share=True)