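"""Training script: conditional denoising diffusion on paired image data.

A UNet is trained to predict the noise added to a high-quality (HQ) image,
conditioned on the corresponding low-quality (LQ) image concatenated along
the channel dimension. Losses and example outputs are logged to Weights &
Biases, and checkpoints (raw and EMA) are saved periodically.
"""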
import sys
sys.path.append('cclddg')
import wandb
from cclddg.data import get_paired_vqgan, tensor_to_image
from cclddg.core import UNet, Discriminator
from cclddg.ddg_context import DDG_Context
from PIL import Image
import torch
import torchvision.transforms as T
from torch_ema import ExponentialMovingAverage # pip install torch-ema
from tqdm import tqdm
import torch.nn.functional as F
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Training params
n_batches = 101000
batch_size = 5  # Lower this if hitting memory issues
lr = 5e-5
img_size = 128
sr = 1  # Super-resolution factor: 1 (none), 2 or 4
n_steps = 200  # Number of diffusion steps; more may be worth trying
grad_accumulation_steps = 6  # Batch accumulation parameter
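# With gradient accumulation the effective batch size is
# batch_size * grad_accumulation_steps = 5 * 6 = 30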
wandb.init(project='dvq_diff',
           config={
               'n_batches': n_batches,
               'batch_size': batch_size,
               'lr': lr,
               'img_size': img_size,
               'sr': sr,
               'n_steps': n_steps,
               'grad_accumulation_steps': grad_accumulation_steps,
           },
           save_code=True)
# Context
ddg_context = DDG_Context(n_steps=n_steps, beta_min=0.005,
                          beta_max=0.05, device=device)
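# As used below, q_xt_x0(x0, t) is assumed to return a noised sample x_t
# together with the noise that was added, and p_xt(xt, pred_noise, t) to
# perform one reverse (denoising) step of the diffusion process.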
# Model
unet = UNet(image_channels=6, n_channels=128, ch_mults=(1, 1, 2, 2, 2),
            is_attn=(False, False, False, True, True),
            n_blocks=4, use_z=False, z_dim=8, n_z_channels=16,
            use_cloob=False, n_cloob_channels=256,
            n_time_channels=-1, denom_factor=1000).to(device)
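# image_channels=6: the noisy target (3 channels) is concatenated with the
# LQ conditioning image (3 channels) along the channel dimension. Only the
# first 3 output channels are used as the noise prediction (hence the
# [:, :3] slices below).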
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt', map_location=device))  # Resume from a previous EMA checkpoint
if sr == 4:
    # Goal is 4x SR. If the HQ crop is 256px we take a 128px crop from the
    # LQ image (which is already 1/2 resolution), scale it down to 64px and
    # then back up to 256px.
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 2:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 1:
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])
else:
    raise ValueError(f'Unsupported sr value: {sr}')
# Data
data = get_paired_vqgan(batch_size=batch_size)
data_iter = iter(data)
# For logging examples
n_egs = 10
eg_lq, eg_hq = next(data_iter)
eg_lq = lq_tfm(eg_lq[:n_egs]).to(device)*2 - 1  # Scale from [0, 1] to [-1, 1]
eg_hq = hq_tfm(eg_hq[:n_egs]).to(device)*2 - 1
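# These fixed examples are reused at every logging step below, so progress
# is visually comparable across the run.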
def eg_im(eg_lq, eg_hq, ddg_context, start_t=99):
    """Sample conditioned on eg_lq and return a PIL grid: one row per
    example, columns = 5 intermediate samples, HQ target, LQ input."""
    batch_size = eg_lq.shape[0]
    all_ims = [[] for _ in range(batch_size)]
    # Start from a noised version of the conditioning image
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps - 1)
    t = torch.tensor(start_t, dtype=torch.long).to(device)
    x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))
    for i in range(start_t):
        t = torch.tensor(start_t - i - 1, dtype=torch.long).to(device)
        with torch.no_grad():
            unet_input = torch.cat((x, cond_0), dim=1)
            pred_noise = unet(unet_input, t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
        # Keep 5 evenly spaced intermediates (assumes start_t is large
        # enough for this spacing, e.g. the 99/120/199 values used here)
        if i % (start_t//4 - 1) == 0:
            for b in range(batch_size):
                all_ims[b].append(tensor_to_image(x[b].cpu()))
    # HQ target:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(eg_hq[b].cpu()))
    # Input/cond:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(cond_0[b].cpu()))
    # 7 columns per row: 5 intermediates + HQ target + conditioning input
    image = Image.new('RGB', size=(img_size*7, batch_size*img_size))
    for i in range(7):
        for b in range(batch_size):
            image.paste(all_ims[b][i], (i*img_size, b*img_size))
    return image
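# Optional sanity check before training: something like
# eg_im(eg_lq, eg_hq, ddg_context, start_t=120) should return a
# (7*img_size) x (n_egs*img_size) PIL image grid.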
# Training Loop
losses = [] # Store losses for later plotting
optim = torch.optim.RMSprop(unet.parameters(), lr=lr) # Optimizer
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995) # EMA
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)
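# The ExponentialLR scheduler is stepped every 4000 batches (at the end of
# the loop below), multiplying the learning rate by 0.9 each time.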
for i in tqdm(range(n_batches)):  # Run through the dataset
    # Get a batch, restarting the data iterator if it runs out
    try:
        lq, hq = next(data_iter)
    except StopIteration:
        data_iter = iter(data)
        lq, hq = next(data_iter)
    lq = lq_tfm(lq).to(device)*2 - 1  # Scale from [0, 1] to [-1, 1]
    hq = hq_tfm(hq).to(device)*2 - 1
    batch_size = lq.shape[0]
    x0 = hq
    cond_0 = lq
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device)  # Random 't's
    xt, noise = ddg_context.q_xt_x0(x0, t)  # Get the noised images (xt) and the noise (our target)
    unet_input = torch.cat((xt, cond_0), dim=1)  # Combine with the conditioning image
    pred_noise = unet(unet_input, t)[:, :3]  # Run xt through the network to get its predictions
    loss = F.mse_loss(noise.float(), pred_noise)  # Compare the predictions with the targets
    losses.append(loss.item())  # Store the loss for later viewing
    wandb.log({'Loss': loss.item()})  # Log to wandb
    # Scale the loss so that accumulated gradients average (rather than sum)
    # over the accumulation window, then backpropagate
    (loss / grad_accumulation_steps).backward()
    if (i + 1) % grad_accumulation_steps == 0:
        optim.step()  # Update the network parameters
        optim.zero_grad()  # Zero the gradients
        ema.update()  # Update the moving average with the new parameters
    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({'Examples @120': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=120))})
            wandb.log({'Examples @199': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=199))})
            wandb.log({'Random Examples @120': wandb.Image(eg_im(lq, hq, ddg_context, start_t=120))})
    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        with ema.average_parameters():  # Temporarily swap in the EMA weights for saving
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')
    if (i + 1) % 4000 == 0:
        scheduler.step()  # Decay the learning rate
        wandb.log({'lr': optim.param_groups[0]['lr']})