import sys

sys.path.append('cclddg')

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import wandb
from PIL import Image
from torch_ema import ExponentialMovingAverage
from tqdm import tqdm

from cclddg.core import UNet, Discriminator
from cclddg.data import get_paired_vqgan, tensor_to_image
from cclddg.ddg_context import DDG_Context

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
n_batches = 101000
batch_size = 5
lr = 5e-5
img_size = 128
sr = 1  # super-resolution factor (1 = no extra downscaling of the input)
n_steps = 200  # number of diffusion steps
grad_accumulation_steps = 6

wandb.init(project='dvq_diff',
           config={
               'n_batches': n_batches,
               'batch_size': batch_size,
               'lr': lr,
               'img_size': img_size,
               'sr': sr,
               'n_steps': n_steps,
               'grad_accumulation_steps': grad_accumulation_steps,
           },
           save_code=True)

ddg_context = DDG_Context(n_steps=n_steps, beta_min=0.005,
                          beta_max=0.05, device=device)
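
# NB: the following reflects an assumption about DDG_Context's API inferred
# from how it is used below, not documented behaviour: q_xt_x0(x0, t) appears
# to implement the standard DDPM forward process,
#     x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,
# returning (x_t, eps), while p_xt(xt, pred_noise, t) appears to take one
# reverse step from x_t to x_{t-1} given the predicted noise.
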
# 6 input channels: 3 for the noisy target image plus 3 for the conditioning
# (low-quality) image, concatenated along the channel dimension.
unet = UNet(image_channels=6, n_channels=128, ch_mults=(1, 1, 2, 2, 2),
            is_attn=(False, False, False, True, True),
            n_blocks=4, use_z=False, z_dim=8, n_z_channels=16,
            use_cloob=False, n_cloob_channels=256,
            n_time_channels=-1, denom_factor=1000).to(device)

# Warm-start from the EMA weights of a previous run.
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt'))
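
# A sketch of the conditioning scheme used throughout this script (my reading
# of the calls below, so treat the details as assumptions): the LQ image is
# concatenated to the noisy image channel-wise, and only the first 3 output
# channels are kept as the noise prediction:
#   unet_input = torch.cat((xt, cond_0), dim=1)   # (B, 6, H, W)
#   pred_noise = unet(unet_input, t)[:, :3]       # (B, 3, H, W)
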
# Build the low-quality (input) and high-quality (target) transforms for the
# chosen super-resolution factor.
if sr == 4:
    # Squeeze the crop through img_size//4 then stretch it back to img_size,
    # so the network must restore 4x detail.
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 2:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 1:
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])
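
# Note (assumption): for sr > 1 the LQ stream from get_paired_vqgan is
# presumably stored at half the HQ resolution, so CenterCrop(img_size//2) on
# the LQ image covers the same field of view as CenterCrop(img_size) on the
# HQ image; if both streams were the same resolution these crops would not
# be spatially aligned.
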
data = get_paired_vqgan(batch_size=batch_size)
data_iter = iter(data)

# Grab a fixed batch of examples for periodic visual logging, mapping pixel
# values from [0, 1] to [-1, 1] to match the model's expected range.
n_egs = 10
eg_lq, eg_hq = next(data_iter)
eg_lq = lq_tfm(eg_lq[:n_egs]).to(device)*2 - 1
eg_hq = hq_tfm(eg_hq[:n_egs]).to(device)*2 - 1

def eg_im(eg_lq, eg_hq, ddg_context, start_t=99):
    """Run the reverse diffusion process starting from a noised conditioning
    image and return a grid with one row per example: 5 denoising snapshots,
    the HQ target, and the LQ input."""
    batch_size = eg_lq.shape[0]
    all_ims = [[] for _ in range(batch_size)]

    # Noise the conditioning image to step start_t as the starting point.
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps - 1)
    t = torch.tensor(start_t, dtype=torch.long).to(device)
    x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))

    for i in range(start_t):
        t = torch.tensor(start_t - i - 1, dtype=torch.long).to(device)
        with torch.no_grad():
            # Condition by concatenating the LQ image along the channel dim;
            # the first 3 output channels are the noise prediction.
            unet_input = torch.cat((x, cond_0), dim=1)
            pred_noise = unet(unet_input, t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
        # Save evenly spaced snapshots of the denoising trajectory
        # (5 for the start_t values used below).
        if i % (start_t//4 - 1) == 0:
            for b in range(batch_size):
                all_ims[b].append(tensor_to_image(x[b].cpu()))

    # Append the HQ target and the LQ conditioning image for comparison.
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(eg_hq[b].cpu()))
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(cond_0[b].cpu()))

    # Assemble the 7-column grid, one row per example.
    image = Image.new('RGB', size=(img_size*7, batch_size*img_size))
    for i in range(7):
        for b in range(batch_size):
            image.paste(all_ims[b][i], (i*img_size, b*img_size))

    return image

losses = []
optim = torch.optim.RMSprop(unet.parameters(), lr=lr)
# Keep an exponential moving average of the weights for smoother checkpoints.
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)
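
# The loop below trains with the standard DDPM epsilon-prediction objective
# (my reading of the code, not a documented claim): sample a random step t,
# noise the HQ image to x_t, and regress the sampled noise,
#     L = E_{x0, t, eps} || eps - eps_theta(cat(x_t, lq), t) ||^2,
# with the LQ image supplying the conditioning channels.
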
for i in tqdm(range(0, n_batches)):

    # Fetch the next batch, restarting the iterator if the data is exhausted.
    try:
        lq, hq = next(data_iter)
    except StopIteration:
        data_iter = iter(data)
        lq, hq = next(data_iter)
    lq = lq_tfm(lq).to(device)*2 - 1
    hq = hq_tfm(hq).to(device)*2 - 1
    batch_size = lq.shape[0]

    # Noise the HQ image to a random step t and predict that noise,
    # conditioned on the LQ image.
    x0 = hq
    cond_0 = lq
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device)
    xt, noise = ddg_context.q_xt_x0(x0, t)
    unet_input = torch.cat((xt, cond_0), dim=1)
    pred_noise = unet(unet_input, t)[:, :3]
    loss = F.mse_loss(noise.float(), pred_noise)
    losses.append(loss.item())
    wandb.log({'Loss': loss.item()})
    loss.backward()

    # Step the optimiser every grad_accumulation_steps batches. The loss is
    # not divided by grad_accumulation_steps, so the effective update uses
    # the *sum* of the accumulated gradients.
    if i % grad_accumulation_steps == 0:
        optim.step()
        optim.zero_grad()
        ema.update()

    # Periodically log example outputs, for both the fixed example batch and
    # the current (random) batch.
    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({'Examples @120': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=120))})
            wandb.log({'Examples @199': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=199))})
            wandb.log({'Random Examples @120': wandb.Image(eg_im(lq, hq, ddg_context, start_t=120))})

    # Periodically checkpoint both the raw and the EMA weights.
    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        with ema.average_parameters():
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')

    # Decay the learning rate every 4000 batches.
    if (i+1) % 4000 == 0:
        scheduler.step()
        wandb.log({'lr': optim.param_groups[0]['lr']})
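
# A later run can warm-start from one of the EMA checkpoints saved above in
# the same way this script loads the 'desert_dawn' weights, e.g.
# (hypothetical filename):
#   unet.load_state_dict(torch.load('ema_unet_020000.pt'))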