# text-to-image-model/wrapper.py
import torch
import torch.nn as nn
from enum import Enum
from tqdm import trange

Schedule = Enum("Schedule", ["LINEAR", "COSINE"])


class DiffusionManager(nn.Module):
    def __init__(self, model: nn.Module, noise_steps=1000, start=0.0001, end=0.02, device="cpu", **kwargs) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.noise_steps = noise_steps
        # Endpoints of the linear beta schedule (start is also reused as the
        # cosine-schedule offset below).
        self.start = start
        self.end = end
        self.device = device
        self.schedule = None
        self.set_schedule()
        # model.set_parent(self)

    def _get_schedule(self, schedule_type: Schedule = Schedule.LINEAR):
        if schedule_type == Schedule.LINEAR:
            return torch.linspace(self.start, self.end, self.noise_steps)
        elif schedule_type == Schedule.COSINE:
            # https://arxiv.org/pdf/2102.09672 page 4
            # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py (line 18)
            def get_alphahat_at(t):
                def f(t):
                    # The paper fixes the offset s = 0.008; here self.start is reused as s.
                    s = self.start
                    return torch.cos((t / self.noise_steps + s) / (1 + s) * torch.pi / 2) ** 2

                return f(t) / f(torch.zeros_like(t))

            t = torch.arange(self.noise_steps, dtype=torch.float32)
            betas = 1 - (get_alphahat_at(t + 1) / get_alphahat_at(t))
            # "In practice, we clip β_t to be no larger than 0.999 to prevent
            # singularities at the end of the diffusion process."
            return torch.minimum(betas, torch.ones_like(betas) * 0.999)

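    # The cosine schedule above follows Nichol & Dhariwal (2021), eq. 17:
    #
    #     alpha_hat(t) = f(t) / f(0),    f(t) = cos(((t / T + s) / (1 + s)) * pi / 2) ** 2
    #     beta_t = 1 - alpha_hat(t + 1) / alpha_hat(t)
    #
    # with beta_t clipped at 0.999. Note the paper uses the fixed offset
    # s = 0.008, whereas this implementation substitutes self.start for s.
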
    def set_schedule(self, schedule: Schedule = Schedule.LINEAR):
        self.schedule = self._get_schedule(schedule).to(self.device)
        # Cache alpha and its cumulative product so get_schedule_at() does not
        # recompute the cumprod on every call.
        self.alpha = 1 - self.schedule
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def get_schedule_at(self, step):
        return (
            self._unsqueezify(self.schedule[step]),
            self._unsqueezify(self.alpha[step]),
            self._unsqueezify(self.alpha_hat[step]),
        )

    @staticmethod
    def _unsqueezify(value):
        # Reshape (B,) schedule values to (B, 1, 1, 1) so they broadcast over images.
        return value.view(-1, 1, 1, 1)

    def noise_image(self, image, step):
        image = image.to(self.device)
        beta, alpha, alpha_hat = self.get_schedule_at(step)
        epsilon = torch.randn_like(image)
        noised_img = torch.sqrt(alpha_hat) * image + torch.sqrt(1 - alpha_hat) * epsilon
        return noised_img, epsilon

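    # noise_image() is the closed-form forward process q(x_t | x_0) from DDPM
    # (Ho et al., 2020):
    #
    #     x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps,  eps ~ N(0, I)
    #
    # A minimal usage sketch (assuming `manager` is a DiffusionManager and
    # `images` is a (B, 3, H, W) batch):
    #
    #     t = manager.random_timesteps(images.shape[0])
    #     noisy, eps = manager.noise_image(images, t)
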
    def random_timesteps(self, amt=1):
        # Timesteps are drawn from [1, noise_steps); t = 0 is the clean image.
        return torch.randint(low=1, high=self.noise_steps, size=(amt,))

    def sample(self, img_size, condition, amt=5, use_tqdm=True):
        # Broadcast a single condition (or a short batch) up to exactly `amt`
        # rows; the slice guards against over-repeating when 1 < rows < amt.
        if condition.shape[0] < amt:
            condition = condition.repeat(amt, 1)[:amt]
        self.model.eval()
        condition = condition.to(self.device)

        def my_trange(x, y, z):
            return trange(x, y, z, leave=False, dynamic_ncols=True)

        fn = my_trange if use_tqdm else range
        with torch.no_grad():
            # Start from pure Gaussian noise and denoise step by step.
            cur_img = torch.randn((amt, 3, img_size, img_size)).to(self.device)
            for i in fn(self.noise_steps - 1, 0, -1):
                timestep = (torch.ones(amt) * i).to(self.device)
                predicted_noise = self.model(cur_img, timestep, condition)
                beta, alpha, alpha_hat = self.get_schedule_at(i)
                cur_img = (1 / torch.sqrt(alpha)) * (cur_img - (beta / torch.sqrt(1 - alpha_hat)) * predicted_noise)
                if i > 1:
                    # Re-inject noise on every step except the last.
                    cur_img = cur_img + torch.sqrt(beta) * torch.randn_like(cur_img)
        self.model.train()
        return cur_img

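    # The update inside sample() is the DDPM reverse step (Ho et al., 2020, Alg. 2):
    #
    #     x_{t-1} = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_hat_t)) * eps_theta)
    #               + sqrt(beta_t) * z,    z ~ N(0, I), omitted on the final step
    #
    # i.e. the posterior variance is taken to be sigma_t^2 = beta_t.
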
    def sample_multicond(self, img_size, condition, use_tqdm=True):
        # Identical to sample(), but draws one image per condition row.
        return self.sample(img_size, condition, amt=condition.shape[0], use_tqdm=use_tqdm)

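    # Usage sketch (assuming `cond_batch` is an (N, cond_dim) tensor of
    # condition vectors): returns an (N, 3, img_size, img_size) batch with
    # one sample per row.
    #
    #     imgs = manager.sample_multicond(img_size=64, condition=cond_batch)
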
    def training_loop_iteration(self, optimizer, batch, label, criterion):
        def print_(string):
            # Repeat the warning so it stands out in a scrolling log.
            for i in range(10):
                print(string)

        batch = batch.to(self.device)
        # label = label.long()  # uncomment for nn.Embedding
        label = label.to(self.device)
        timesteps = self.random_timesteps(batch.shape[0]).to(self.device)
        noisy_batch, real_noise = self.noise_image(batch, timesteps)
        if torch.isnan(noisy_batch).any() or torch.isnan(real_noise).any():
            print_("NaNs detected in the noisy batch or real noise")
        pred_noise = self.model(noisy_batch, timesteps, label)
        if torch.isnan(pred_noise).any():
            print_("NaNs detected in the predicted noise")
        loss = criterion(real_noise, pred_noise)
        if torch.isnan(loss).any():
            print_("NaNs detected in the loss")
        optimizer.zero_grad()  # clear stale gradients before backprop
        loss.backward()
        optimizer.step()
        return loss.item()
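

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module.
    # _ToyNoisePredictor is a hypothetical stand-in for the repo's real
    # text-conditioned denoiser; it only matches the (image, timestep,
    # condition) calling convention that DiffusionManager expects.
    class _ToyNoisePredictor(nn.Module):
        def __init__(self, cond_dim=8):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
            self.cond_proj = nn.Linear(cond_dim + 1, 3)

        def forward(self, x, t, cond):
            # Fold the timestep into the condition vector, project it to a
            # per-channel bias, and add that to a conv pass over the image.
            feat = torch.cat([cond, t.float().unsqueeze(-1)], dim=-1)
            bias = self.cond_proj(feat).view(-1, 3, 1, 1)
            return self.conv(x) + bias

    manager = DiffusionManager(_ToyNoisePredictor(), noise_steps=100)
    optimizer = torch.optim.Adam(manager.model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    batch = torch.randn(4, 3, 16, 16)   # stand-in images
    label = torch.randn(4, 8)           # stand-in condition vectors
    loss = manager.training_loop_iteration(optimizer, batch, label, criterion)
    print(f"one training step, loss={loss:.4f}")

    samples = manager.sample(img_size=16, condition=torch.randn(1, 8), amt=2, use_tqdm=False)
    print("sampled:", tuple(samples.shape))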