# Source: zen21 -- "Update app.py" (commit 645ab7f)
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class DoubleConv(nn.Module):
    """Two bias-free 3x3 convolutions, each followed by a one-group GroupNorm.

    A GELU separates the two conv/norm pairs. With ``residual=True`` the
    input is added to the stack's output and a final GELU is applied.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; any falsy
            value (``None`` or ``0``) falls back to ``out_channels``.
        residual: Whether to wrap the stack in a skip connection.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels or out_channels
        # Layer order is kept as-is so the state_dict indices (0, 1, 3, 4)
        # stay the same.
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack, optionally as a GELU-activated residual."""
        out = self.double_conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)
class Down(nn.Module):
    """Encoder stage: 2x max-pool, two DoubleConv blocks, plus a time bias.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the stage output.
        emb_dim: Width of the time-embedding vector fed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)
        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample ``x`` and add the projected embedding ``t`` per channel."""
        features = self.maxpool_conv(x)
        # Trailing singleton axes broadcast the bias over height and width,
        # giving the same values as an explicit repeat to the spatial size.
        time_bias = self.emb_layer(t)[:, :, None, None]
        return features + time_bias
class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, skip concat, two DoubleConv blocks.

    Args:
        in_channels: Channel count after concatenating the skip tensor with
            the upsampled tensor.
        out_channels: Channels of the stage output.
        emb_dim: Width of the time-embedding vector fed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Projects the time embedding to one scalar per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        """Upsample ``x``, fuse it with ``skip_x`` and add the time bias.

        Args:
            x: Feature map from the previous (coarser) stage.
            skip_x: Encoder feature map concatenated in front of ``x`` on dim 1.
            t: Time embedding with ``emb_dim`` features per sample.

        Returns:
            The fused feature map with ``out_channels`` channels.
        """
        x = self.up(x)
        # MaxPool2d floors odd sizes, so doubling can land one pixel short of
        # the skip tensor; resize only in that case so matching inputs follow
        # exactly the original path.
        target_size = tuple(skip_x.shape[-2:])
        if tuple(x.shape[-2:]) != target_size:
            x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=True)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        # Trailing singleton axes broadcast the bias over height and width.
        return x + self.emb_layer(t)[:, :, None, None]
class UNet(nn.Module):
    """Time-conditioned UNet that maps an image and a timestep to an image.

    Args:
        c_in: Channels of the input image.
        c_out: Channels of the output tensor.
        time_dim: Width of the sinusoidal timestep embedding.
        device: Stored on the instance for callers that read it; the forward
            pass itself follows the device of its inputs.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        # emb_dim is forwarded so a non-default time_dim matches the width
        # expected by each stage's embedding projection.
        self.down1 = Down(64, 128, emb_dim=time_dim)
        self.down2 = Down(128, 256, emb_dim=time_dim)
        self.down3 = Down(256, 256, emb_dim=time_dim)
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)
        self.up1 = Up(512, 128, emb_dim=time_dim)
        self.up2 = Up(256, 64, emb_dim=time_dim)
        self.up3 = Up(128, 64, emb_dim=time_dim)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def positional_encoding(self, t, channels):
        """Return the sinusoidal embedding of ``t``.

        Args:
            t: Tensor whose last dimension has size 1 (one timestep per row).
            channels: Embedding width; the first half holds sines and the
                second half cosines.

        Returns:
            Tensor with ``channels`` features in the last dimension, on the
            same device as ``t``.
        """
        # Frequencies live on t's device rather than self.device, so the
        # module keeps working after .to(...) or when the constructor's
        # device string does not match where the inputs actually are.
        exponents = torch.arange(0, channels, 2, device=t.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, image, t):
        """Run the network.

        Args:
            image: Input batch for the first convolution.
            t: 1-D tensor with one timestep per batch element.

        Returns:
            Output of the final 1x1 convolution (``c_out`` channels).
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.positional_encoding(t, self.time_dim)
        x1 = self.inc(image)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)
        x4 = self.bot1(x4)
        # bot2 is constructed but intentionally not applied here (its call was
        # disabled in the original code). NOTE(review): it presumably stays
        # registered so checkpoint keys still match -- confirm before removing.
        x4 = self.bot3(x4)
        x = self.up1(x4, x3, t)
        x = self.up2(x, x2, t)
        x = self.up3(x, x1, t)
        return self.outc(x)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet(device=device).to(device)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; consider passing
# weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load('Model_Saved_States/diffusion_64.pth', map_location=torch.device(device)))
# The model is only used for sampling here, so switch it to inference mode.
model.eval()
# Not referenced by the other code visible in this file.
img_size = 64
class Diffusion:
    """Linear-beta noising schedule with an ancestral sampling loop.

    Args:
        time_steps: Number of entries in the beta schedule.
        beta_start: First value of the linear beta schedule.
        beta_stop: Last value of the linear beta schedule.
        image_size: Side length of the square samples that are generated.
        device: Device for the schedule tensors and generated samples.
            ``None`` selects ``"cuda"`` when available, otherwise ``"cpu"``.
    """

    def __init__(self, time_steps=500, beta_start=0.0001, beta_stop=0.02, image_size=64, device=None):
        # Resolved at call time instead of binding a module-level global when
        # the class is defined; the resulting default is the same.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.time_steps = time_steps
        self.beta_start = beta_start
        self.beta_stop = beta_stop
        self.img_size = image_size
        self.device = device
        self.beta = self.beta_schedule().to(device)
        self.alpha = 1 - self.beta
        # Running product of alpha over the schedule.
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def beta_schedule(self):
        """Return ``time_steps`` betas spaced linearly from start to stop."""
        return torch.linspace(self.beta_start, self.beta_stop, self.time_steps)

    def noise_images(self, images, t):
        """Noise ``images`` to the timesteps in ``t``.

        Args:
            images: 4-D batch of images.
            t: 1-D integer tensor of timestep indices, one per image.

        Returns:
            Tuple ``(noised_images, noises)`` where ``noises`` is the Gaussian
            noise that was mixed in.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noises = torch.randn_like(images)
        noised_images = sqrt_alpha_hat * images + sqrt_one_minus_alpha_hat * noises
        return noised_images, noises

    def random_timesteps(self, n):
        """Return ``n`` random integer timesteps in ``[1, time_steps - 1]``."""
        return torch.randint(low=1, high=self.time_steps, size=(n,))

    def generate_samples(self, model, n):
        """Sample ``n`` images by iterating the reverse process.

        Args:
            model: Callable ``model(x, t)`` returning the predicted noise for
                batch ``x`` at integer timesteps ``t``.
            n: Batch size to sample.

        Returns:
            The FIRST sample only, as a numpy array laid out (height, width,
            3) and divided by 255; callers have to undo that scaling.
        """
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size), device=self.device)
            # Walk t = time_steps-1 ... 1. The previous bound stopped at 2,
            # which skipped the last denoising step and made the zero-noise
            # branch below unreachable.
            for i in range(self.time_steps - 1, 0, -1):
                t = torch.full((n,), i, dtype=torch.long, device=self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise is injected on the final step.
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                denoised = x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise
                x = 1 / torch.sqrt(alpha) * denoised + torch.sqrt(beta) * noise
        return x[0].cpu().numpy().transpose(1, 2, 0) / 255
# Module-level sampler used by the request handler; built with the class
# defaults (500 steps, 64-pixel images).
diffusion = Diffusion()
import numpy as np
def greet(n):
    """Generate one image with the module-level model and sampler.

    Args:
        n: Value of the UI's number input. It is currently ignored; exactly
            one sample is generated per call.

    Returns:
        Float numpy array laid out (height, width, 3) with values in [0, 1].
    """
    image = diffusion.generate_samples(model, n=1)
    # generate_samples returns the sample divided by 255, so scale it back,
    # clamp to [-1, 1], then shift into [0, 1] for display.
    image = (np.clip(image * 255, -1, 1) + 1) / 2
    # The former plt.imshow call was dropped: it did not affect the returned
    # image and added a new artist to pyplot's implicit global figure on
    # every request.
    return image
# Web UI: one number input wired to greet(), whose return value is shown as
# an image. The server is started as soon as this module is executed.
iface = gr.Interface(
    fn=greet,
    inputs="number",
    outputs="image",
)
iface.launch()