import sys
sys.path.append('cclddg')

import wandb
from cclddg.data import get_paired_vqgan, tensor_to_image
from cclddg.core import UNet, Discriminator
from cclddg.ddg_context import DDG_Context
from PIL import Image
import torch
import torchvision.transforms as T
from torch_ema import ExponentialMovingAverage  # pip install torch-ema
from tqdm import tqdm
import torch.nn.functional as F

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training params
n_batches = 101000
batch_size = 5  # Lower this if hitting memory issues
lr = 5e-5
img_size = 128
sr = 1
n_steps = 200  # Should try more
grad_accumulation_steps = 6  # Batch accumulation parameter

wandb.init(project='dvq_diff', config={
    'n_batches': n_batches,
    'batch_size': batch_size,
    'lr': lr,
    'img_size': img_size,
    'sr': sr,
    'n_steps': n_steps,
}, save_code=True)

# Context
ddg_context = DDG_Context(n_steps=n_steps, beta_min=0.005, beta_max=0.05, device=device)

# Model
unet = UNet(image_channels=6, n_channels=128,
            ch_mults=(1, 1, 2, 2, 2),
            is_attn=(False, False, False, True, True),
            n_blocks=4,
            use_z=False, z_dim=8, n_z_channels=16,
            use_cloob=False, n_cloob_channels=256,
            n_time_channels=-1, denom_factor=1000).to(device)
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt'))

if sr == 4:
    # Goal is 4x SR. If image size is 256 (hq) we take 128px from lq (which is already
    # 1/2 res), scale it down to 64px, then back up to 256.
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
if sr == 2:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
if sr == 1:
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])

# Data
data = get_paired_vqgan(batch_size=batch_size)
data_iter = iter(data)

# For logging examples
n_egs = 10
eg_lq, eg_hq = next(data_iter)
eg_lq = lq_tfm(eg_lq[:n_egs]).to(device)*2 - 1
eg_hq = hq_tfm(eg_hq[:n_egs]).to(device)*2 - 1

def eg_im(eg_lq, eg_hq, ddg_context, start_t=99):
    batch_size = eg_lq.shape[0]
    all_ims = [[] for _ in range(batch_size)]

    # Start from a noised version of cond_0 (the low-quality input)
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps - 1)
    t = torch.tensor(start_t, dtype=torch.long).to(device)
    x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))

    for i in range(start_t):
        t = torch.tensor(start_t - i - 1, dtype=torch.long).to(device)
        with torch.no_grad():
            unet_input = torch.cat((x, cond_0), dim=1)
            pred_noise = unet(unet_input, t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
        if i % (start_t//4 - 1) == 0:  # Save 5 evenly spaced intermediate steps
            for b in range(batch_size):
                all_ims[b].append(tensor_to_image(x[b].cpu()))

    # HQ target:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(eg_hq[b].cpu()))

    # Input/cond:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(cond_0[b].cpu()))

    # 7 columns: 5 intermediate steps + HQ target + conditioning input
    image = Image.new('RGB', size=(img_size*7, batch_size*img_size))
    for i in range(7):
        for b in range(batch_size):
            image.paste(all_ims[b][i], (i*img_size, b*img_size))
    return image

# Training Loop
losses = []  # Store losses for later plotting
optim = torch.optim.RMSprop(unet.parameters(), lr=lr)  # Optimizer
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995)  # EMA
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

for i in tqdm(range(0, n_batches)):  # Run through the dataset
    # Get a batch, restarting the iterator if the dataset is exhausted
    try:
        lq, hq = next(data_iter)
    except StopIteration:
        data_iter = iter(data)
        lq, hq = next(data_iter)
    lq = lq_tfm(lq).to(device)*2 - 1
    hq = hq_tfm(hq).to(device)*2 - 1
    batch_size = lq.shape[0]

    x0 = hq
    cond_0 = lq
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device)  # Random 't's
    xt, noise = ddg_context.q_xt_x0(x0, t)  # Get the noised images (xt) and the noise (our target)
    unet_input = torch.cat((xt, cond_0), dim=1)  # Combine with cond
    pred_noise = unet(unet_input, t)[:, :3]  # Run xt through the network to get its predictions
    loss = F.mse_loss(noise.float(), pred_noise)  # Compare the predictions with the targets
    losses.append(loss.item())  # Store the loss for later viewing
    wandb.log({'Loss': loss.item()})  # Log to wandb

    loss.backward()  # Backpropagate the loss
    if i % grad_accumulation_steps == 0:
        optim.step()   # Update the network parameters
        optim.zero_grad()  # Zero the gradients
        ema.update()   # Update the moving average with the new parameters from the last optimizer step

    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({'Examples @120': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=120))})
            wandb.log({'Examples @199': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=199))})
            wandb.log({'Random Examples @120': wandb.Image(eg_im(lq, hq, ddg_context, start_t=120))})

    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        with ema.average_parameters():
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')

    if (i+1) % 4000 == 0:
        scheduler.step()
        wandb.log({'lr': optim.param_groups[0]['lr']})
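
# ----------------------------------------------------------------------------------------
# Inference sketch (not part of the training run above): a minimal example of how the
# saved EMA weights *might* be used to enhance a single low-quality image, mirroring the
# sampling loop in eg_im. The checkpoint name and input file below are placeholders, not
# files produced by this exact script run; substitute whatever your own run saved.
def enhance_image(lq_pil, checkpoint='ema_unet_100000.pt', start_t=150):
    # Load EMA weights saved during training (checkpoint name is an assumption)
    unet.load_state_dict(torch.load(checkpoint, map_location=device))
    unet.eval()

    # Same preprocessing as training: transform, move to device, scale to [-1, 1]
    cond_0 = lq_tfm(T.ToTensor()(lq_pil)).unsqueeze(0).to(device)*2 - 1

    # Start from the conditioning image noised to step start_t, then denoise step by step
    start_t = min(start_t, ddg_context.n_steps - 1)
    t = torch.tensor(start_t, dtype=torch.long, device=device)
    x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))
    with torch.no_grad():
        for i in range(start_t):
            t = torch.tensor(start_t - i - 1, dtype=torch.long, device=device)
            pred_noise = unet(torch.cat((x, cond_0), dim=1), t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
    return tensor_to_image(x[0].cpu())

# Example usage (hypothetical input file):
# out = enhance_image(Image.open('my_lowres_photo.jpg').convert('RGB'))
# out.save('enhanced.png')
# ----------------------------------------------------------------------------------------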