# dvq / train_dvq_diff.py
# Uploaded by johnowhitaker ("Upload train_dvq_diff.py", commit 185c7b0).
import sys
sys.path.append('cclddg')
import wandb
from cclddg.data import get_paired_vqgan, tensor_to_image
from cclddg.core import UNet, Discriminator
from cclddg.ddg_context import DDG_Context
from PIL import Image
import torch
import torchvision.transforms as T
from torch_ema import ExponentialMovingAverage # pip install torch-ema
from tqdm import tqdm
import torch.nn.functional as F
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training params
n_batches = 101000  # Number of training iterations (one batch fetched per iteration)
batch_size = 5  # Lower this if hitting memory issues
lr = 5e-5  # Initial learning rate (decayed by an ExponentialLR scheduler in the training loop)
img_size = 128  # Target size (px) used by the lq/hq transforms
sr = 1  # Super-resolution factor; selects the lq/hq transforms (1, 2 or 4)
n_steps = 200  # Number of diffusion steps given to DDG_Context. Should try more
grad_accumulation_steps = 6  # Batches accumulated per optimizer step

wandb.init(project='dvq_diff',
           config={
               'n_batches': n_batches,
               'batch_size': batch_size,
               'lr': lr,
               'img_size': img_size,
               'sr': sr,
               'n_steps': n_steps,
               # Changes the effective batch size, so record it with the run too.
               'grad_accumulation_steps': grad_accumulation_steps,
           },
           save_code=True)
# Context: diffusion schedule with `n_steps` steps (linear-looking beta range is an
# assumption from the beta_min/beta_max names — TODO confirm in DDG_Context).
ddg_context = DDG_Context(n_steps=n_steps, beta_min=0.005,
                          beta_max=0.05, device=device)

# Model. The training loop feeds torch.cat((x_t, cond), dim=1), hence image_channels=6,
# and uses only the first 3 output channels as the noise prediction.
unet = UNet(image_channels=6, n_channels=128, ch_mults=(1, 1, 2, 2, 2),
            is_attn=(False, False, False, True, True),
            n_blocks=4, use_z=False, z_dim=8, n_z_channels=16,
            use_cloob=False, n_cloob_channels=256,
            n_time_channels=-1, denom_factor=1000).to(device)
# Initialise from a saved checkpoint (the filename suggests EMA weights).
# map_location lets a checkpoint saved from a GPU load on the CPU fallback device.
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt', map_location=device))

# lq_tfm builds the conditioning input, hq_tfm the target; `sr` picks the degradation.
if sr == 4:
    # Goal is 4x SR. If image size is 256 (hq) we take 128px from lq (which is already 1/2 res) and scale to 64px then back up to 256
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 2:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 1:
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])
else:
    # Fail here rather than with a NameError on lq_tfm/hq_tfm further down.
    raise ValueError(f'Unsupported sr={sr!r}; expected 1, 2 or 4')
# Data
data = get_paired_vqgan(batch_size=batch_size)
data_iter = iter(data)

# A fixed set of examples, drawn once and reused for every logged sample grid.
# NOTE: the examples are sliced from a single batch, so presumably at most
# `batch_size` of the `n_egs` requested rows are actually available.
n_egs = 10


def _prep_examples(batch, tfm):
    """Apply `tfm` to the first `n_egs` items of `batch`, move to `device`, map x -> 2x - 1.

    The 2x - 1 step presumably rescales [0, 1] images to [-1, 1], matching the
    scaling applied to training batches (TODO confirm against get_paired_vqgan).
    """
    first_items = batch[:n_egs]
    on_device = tfm(first_items).to(device)
    return on_device * 2 - 1


eg_lq, eg_hq = next(data_iter)
eg_lq = _prep_examples(eg_lq, lq_tfm)
eg_hq = _prep_examples(eg_hq, hq_tfm)
def eg_im(eg_lq, eg_hq, ddg_context, start_t = 99):
    """Noise `eg_lq` to step `start_t`, denoise it with `unet`, and return a grid image.

    Each row of the grid is one example. The columns are:
      * snapshots of the sample taken every ``max(start_t // 4 - 1, 1)`` reverse
        steps (5 snapshots for the start_t values this script uses: 99, 120, 199),
      * the hq target from `eg_hq`,
      * the lq conditioning input `eg_lq`.

    Args:
        eg_lq: batch of conditioning images, scaled like the training inputs.
            It is both the conditioning signal and the image that gets noised
            as the starting point.
        eg_hq: matching batch of target images, shown for reference only.
        ddg_context: diffusion helper providing `n_steps`, `q_xt_x0` and `p_xt`.
        start_t: step to noise `eg_lq` to before denoising; clamped to
            ``ddg_context.n_steps - 1``.

    Returns:
        A PIL RGB image with one row per example and `img_size`-pixel tiles.

    Reads the module-level `unet` (noise predictor) and `img_size` (tile size).
    """
    n_examples = eg_lq.shape[0]
    rows = [[] for _ in range(n_examples)]
    # Follow the inputs' device instead of hard-coding CUDA, so CPU runs work.
    dev = eg_lq.device
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps - 1)
    # Same spacing as before for start_t >= 8; clamped so small start_t values
    # cannot divide by zero (or use a negative modulus).
    snapshot_every = max(start_t // 4 - 1, 1)

    with torch.no_grad():
        # Start from a noised copy of cond_0 at step start_t.
        t = torch.tensor(start_t, dtype=torch.long, device=dev)
        x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))
        for i in range(start_t):
            t = torch.tensor(start_t - i - 1, dtype=torch.long, device=dev)
            unet_input = torch.cat((x, cond_0), dim=1)
            pred_noise = unet(unet_input, t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
            if i % snapshot_every == 0:
                for b in range(n_examples):
                    rows[b].append(tensor_to_image(x[b].cpu()))

    # Reference columns: HQ target, then the input/conditioning image.
    for b in range(n_examples):
        rows[b].append(tensor_to_image(eg_hq[b].cpu()))
        rows[b].append(tensor_to_image(cond_0[b].cpu()))

    # Size the grid from what was collected, so the reference columns are never
    # cut off when the snapshot count is not exactly 5.
    n_cols = len(rows[0]) if rows else 0
    image = Image.new('RGB', size=(img_size * n_cols, n_examples * img_size))
    for b, row in enumerate(rows):
        for col, tile in enumerate(row):
            image.paste(tile, (col * img_size, b * img_size))
    return image
# Training Loop
losses = []  # Store losses for later plotting
optim = torch.optim.RMSprop(unet.parameters(), lr=lr)  # Optimizer
# EMA of the weights. It is updated once per optimizer step (below), so `decay`
# is the smoothing factor per actual weight change.
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995)
# Multiplies the lr by gamma on each scheduler.step() (every 4000 iterations below).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)
for i in tqdm(range(0, n_batches)):  # Run through the dataset
    # Get a batch. Raw and transformed tensors use different names, so a reused
    # batch is never transformed and rescaled a second time.
    try:
        lq_raw, hq_raw = next(data_iter)
    except StopIteration:
        # Loader exhausted: start another pass over the data.
        data_iter = iter(data)
        lq_raw, hq_raw = next(data_iter)
    except Exception as err:
        # Best-effort: keep a long run alive by reusing the previous raw batch.
        # On the first iteration there is nothing to reuse.
        if i == 0:
            raise
        tqdm.write(f'Batch {i}: failed to load data ({err!r}); reusing previous batch')
    lq = lq_tfm(lq_raw).to(device)*2-1
    hq = hq_tfm(hq_raw).to(device)*2-1
    batch_size = lq.shape[0]
    x0 = hq  # Clean target
    cond_0 = lq  # Conditioning input
    # One random step per example, drawn from [1, n_steps).
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device)
    xt, noise = ddg_context.q_xt_x0(x0, t)  # Get the noised images (xt) and the noise (our target)
    unet_input = torch.cat((xt, cond_0), dim=1)  # Combine with cond
    pred_noise = unet(unet_input, t)[:, :3]  # First 3 output channels are the noise prediction
    loss = F.mse_loss(noise.float(), pred_noise)  # Compare the predictions with the targets
    loss_value = loss.item()
    losses.append(loss_value)  # Store the (unscaled) loss for later viewing
    # step=i keeps the wandb x-axis equal to the training iteration, even though
    # several log calls may happen in one iteration.
    wandb.log({'Loss': loss_value}, step=i)
    # Average over the accumulated batches instead of summing them.
    (loss / grad_accumulation_steps).backward()
    # Step only once a full group of batches has been accumulated.
    if (i + 1) % grad_accumulation_steps == 0:
        optim.step()  # Update the network parameters
        optim.zero_grad()  # Zero the gradients
        ema.update()  # Track the weights produced by this optimizer step
    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({
                'Examples @120': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=120)),
                'Examples @199': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=199)),
                'Random Examples @120': wandb.Image(eg_im(lq, hq, ddg_context, start_t=120)),
            }, step=i)
    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        # Temporarily swap the EMA weights into `unet` to save them.
        with ema.average_parameters():
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')
    if (i + 1) % 4000 == 0:
        scheduler.step()
        wandb.log({'lr': optim.param_groups[0]['lr']}, step=i)