import streamlit as st
import torch
import numpy as np
import math
import torch.nn.functional as F
from tqdm.auto import tqdm
from torchvision import transforms
from torch import nn

# Sampling configuration: image resolution, batch size, and compute device.
img_size = 64
BATCH_SIZE = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Block(nn.Module):
    """One U-Net stage: two convolutions with a time-embedding injection, followed by a
    downsampling (strided conv) or upsampling (transposed conv) transform."""

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # On the up path the input is concatenated with a skip connection,
            # so the first convolution sees 2 * in_ch channels.
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
            self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
            self.maxpool = nn.MaxPool2d(4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.silu = nn.SiLU()
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First conv block.
        h = self.silu(self.bnorm1(self.conv1(x)))
        # Project the time embedding to out_ch channels and add it to every spatial location.
        time_emb = self.relu(self.time_mlp(t))
        time_emb = time_emb[(..., ) + (None, ) * 2]
        h = h + time_emb
        # Second conv block.
        h = self.silu(self.bnorm2(self.conv2(h)))
        # Down- or up-sample the result.
        return self.transform(h)

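# A note on the time-embedding broadcast above: `time_emb[(..., ) + (None, ) * 2]` appends
# two singleton dimensions, turning a (batch, out_ch) tensor into (batch, out_ch, 1, 1) so
# it broadcasts against the (batch, out_ch, H, W) feature map. A minimal sketch of the same
# reshape (illustrative only, not used by the app):
#
#   emb = torch.randn(8, 64)            # (batch, out_ch)
#   emb = emb[(..., ) + (None, ) * 2]   # (8, 64, 1, 1)
#   fmap = torch.randn(8, 64, 32, 32)
#   out = fmap + emb                    # broadcasts over H and W
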

class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embeddings of the diffusion timestep."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

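# For reference, the embedding computed above is the standard sinusoidal encoding:
#
#   emb(t)[i]         = sin(t * exp(-i * ln(10000) / (d/2 - 1)))   for i < d/2
#   emb(t)[i + d/2]   = cos(t * exp(-i * ln(10000) / (d/2 - 1)))
#
# i.e. a bank of sines and cosines at geometrically spaced frequencies. For a batch of
# timesteps of shape (batch,), the output has shape (batch, dim).
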
class SimpleUnet(nn.Module):
    """
    A simplified variant of the U-Net architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (32, 64, 128, 256, 512)
        up_channels = (512, 256, 128, 64, 32)
        out_dim = 3
        time_emb_dim = 32

        # Embed the timestep, then project it with a small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection from RGB to the first feature width.
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsampling and upsampling stages.
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        # Down path: store each stage's output for the skip connections.
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        # Up path: concatenate the matching skip connection before each stage.
        for up in self.ups:
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)

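# A quick shape check of the network (illustrative only; left commented out so the
# Streamlit app does not run it on every page load):
#
#   m = SimpleUnet()
#   x = torch.randn(2, 3, img_size, img_size)
#   t = torch.randint(0, 300, (2,), dtype=torch.long)
#   out = m(x, t)   # expected shape: (2, 3, img_size, img_size)
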
def linear_beta_schedule(timesteps):
    """Linearly increasing noise schedule from beta_start to beta_end."""
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

timesteps = 300
betas = linear_beta_schedule(timesteps=timesteps)

# Precompute the quantities used by the forward and reverse diffusion processes.
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
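
# These terms correspond to the standard DDPM definitions (Ho et al., 2020):
#
#   alpha_t      = 1 - beta_t
#   alpha_bar_t  = prod_{s <= t} alpha_s
#   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
#   beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
#
# so `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` give the closed-form
# forward process, and `posterior_variance` is the variance used when sampling backwards.
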

def extract(a, t, x_shape):
    """Pick out the schedule value a[t] for each element of the batch and reshape it
    so it broadcasts against a tensor of shape x_shape."""
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

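# For completeness, this is how the precomputed terms above are typically used to noise a
# clean image in one shot (the closed-form forward process). It is an illustrative sketch
# and is not called anywhere in this app.
def q_sample(x_start, t, noise=None):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
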

@torch.no_grad()
def p_sample(model, x, t, t_index):
    """One reverse-diffusion step: predict the noise with the model and sample x_{t-1}."""
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Use the predicted noise to compute the posterior mean.
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        # No noise is added at the final step.
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
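
# Written out, the update implemented above is:
#
#   mu_theta(x_t, t) = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
#   x_{t-1}          = mu_theta(x_t, t) + sqrt(beta_tilde_t) * z,   z ~ N(0, I)  (z omitted at t = 0)
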

@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse process from pure noise, keeping the image at every timestep."""
    device = next(model.parameters()).device
    b = shape[0]

    # Start from Gaussian noise and denoise step by step.
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs


@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
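
# `sample` returns a list of length `timesteps`; each entry is a numpy array of shape
# (batch_size, channels, image_size, image_size), and the last entry is the fully
# denoised batch used for display below.
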

model = SimpleUnet()

st.title("Generating images using a diffusion model")

# Load the trained weights on CPU and switch to inference mode.
model.load_state_dict(torch.load("new_linear_model_1090.pt", map_location=torch.device('cpu')))
model.eval()

if st.button("Click to generate image"):
    samples = sample(model, image_size=img_size, batch_size=BATCH_SIZE, channels=3)
    # Display the first image of the final denoising step.
    for i in range(1):
        # Map the model output from [-1, 1] back to an 8-bit PIL image.
        reverse_transforms = transforms.Compose([
            transforms.Lambda(lambda t: (t + 1) / 2),
            transforms.Lambda(lambda t: t.permute(1, 2, 0)),
            transforms.Lambda(lambda t: t * 255.),
            transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
            transforms.ToPILImage(),
        ])
        img = reverse_transforms(torch.Tensor(samples[-1][i].reshape(3, img_size, img_size)))

        st.image(img)
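
# To launch the app locally (assuming this file is saved as, e.g., app.py and the
# checkpoint "new_linear_model_1090.pt" sits next to it):
#
#   streamlit run app.py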