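"""Fine-tune a conditional denoising diffusion model on paired LQ/HQ images.

The low-quality image is concatenated channel-wise with the noised target and
the UNet is trained to predict the added noise; the transforms below can set
the task up as a 2x or 4x super-resolution problem. Losses and sample grids
are logged to Weights & Biases."""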
import sys
sys.path.append('cclddg')
import wandb
from cclddg.data import get_paired_vqgan, tensor_to_image
from cclddg.core import UNet, Discriminator
from cclddg.ddg_context import DDG_Context
from PIL import Image
import torch
import torchvision.transforms as T
from torch_ema import ExponentialMovingAverage # pip install torch-ema
from tqdm import tqdm
import torch.nn.functional as F
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training params
n_batches = 101000
batch_size = 5 # Lower this if hitting memory issues
lr = 5e-5
img_size = 128
sr = 1 # Super-resolution factor: 1 (none), 2 or 4 (see the transforms below)
n_steps = 200 # Number of diffusion steps; should try more
grad_accumulation_steps = 6 # Batch accumulation parameter
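# With accumulation, the effective batch size per optimizer step is
# batch_size * grad_accumulation_steps = 30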

wandb.init(project='dvq_diff',
           config={
               'n_batches': n_batches,
               'batch_size': batch_size,
               'lr': lr,
               'img_size': img_size,
               'sr': sr,
               'n_steps': n_steps,
               'grad_accumulation_steps': grad_accumulation_steps,
           },
           save_code=True)

# Context
ddg_context = DDG_Context(n_steps=n_steps, beta_min=0.005, 
                              beta_max=0.05, device=device)
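# (A note, assuming DDG_Context follows the standard DDPM formulation, which
# is not shown here: with alpha_t = 1 - beta_t and alpha_bar_t their cumulative
# product, q_xt_x0 samples x_t = sqrt(alpha_bar_t)*x0 + sqrt(1-alpha_bar_t)*eps
# and p_xt takes one reverse step from x_t given the predicted noise eps.)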

# Model
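# image_channels=6: the 3-channel noised target and the 3-channel LQ condition
# are concatenated along the channel dimension before entering the UNet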
unet = UNet(image_channels=6, n_channels=128, ch_mults=(1, 1, 2, 2, 2), 
            is_attn=(False, False, False, True, True),
            n_blocks=4, use_z=False, z_dim=8, n_z_channels=16, 
            use_cloob=False, n_cloob_channels=256, 
            n_time_channels=-1, denom_factor=1000).to(device)
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt')) # Resume from a previous EMA checkpoint


if sr == 4:
    # 4x SR: crop img_size//2 px from the LQ source (which is already half
    # resolution), downscale to img_size//4, then upscale back to img_size
    # as the conditioning image. E.g. for a 256px HQ target, the condition
    # is effectively 64px of detail upscaled to 256px.
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 2:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 1:
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])


# Data
data = get_paired_vqgan(batch_size=batch_size)
data_iter = iter(data)
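
# Optional sanity check (a sketch, left commented out; assumes the loader
# yields (B, 3, H, W) float tensors in [0, 1] with roughly square images):
# _lq, _hq = next(data_iter)
# print(lq_tfm(_lq).shape, hq_tfm(_hq).shape)  # Spatial sizes should match for channel concat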

# For logging examples
n_egs = 10
eg_lq, eg_hq = next(data_iter)
eg_lq = lq_tfm(eg_lq[:n_egs]).to(device)*2-1
eg_hq = hq_tfm(eg_hq[:n_egs]).to(device)*2-1
def eg_im(eg_lq, eg_hq, ddg_context, start_t=99):
    """Sample from the model and return a grid image: one row per example,
    with ~5 intermediate denoising snapshots, then the HQ target, then the
    LQ conditioning image."""
    batch_size = eg_lq.shape[0]
    all_ims = [[] for _ in range(batch_size)]

    # Start from a noised version of the conditioning image rather than pure noise
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps-1)
    t = torch.tensor(start_t, dtype=torch.long).to(device)
    x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))
    for i in range(start_t):
        t = torch.tensor(start_t-i-1, dtype=torch.long).to(device)
        with torch.no_grad():
            unet_input = torch.cat((x, cond_0), dim=1) # Condition by channel concat
            pred_noise = unet(unet_input, t.unsqueeze(0))[:,:3] # First 3 channels are the noise prediction
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0)) # One reverse diffusion step
            if i % (start_t//4 - 1) == 0: # Snapshot 5 intermediate steps (holds for the start_t values used below)
                for b in range(batch_size):
                    all_ims[b].append(tensor_to_image(x[b].cpu()))

    # HQ target:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(eg_hq[b].cpu()))
    # Input/cond:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(cond_0[b].cpu()))

    # 7 columns: 5 snapshots + HQ target + LQ condition
    image = Image.new('RGB', size=(img_size*7, batch_size*img_size))
    for i in range(7):
        for b in range(batch_size):
            image.paste(all_ims[b][i], (i*img_size, b*img_size))

    return image

# Training Loop
losses = [] # Store losses for later plotting
optim = torch.optim.RMSprop(unet.parameters(), lr=lr) # Optimizer
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995) # EMA
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

for i in tqdm(range(n_batches)): # Run through the dataset

    # Get a batch, restarting the iterator when the dataset is exhausted
    try:
        lq, hq = next(data_iter)
    except StopIteration:
        data_iter = iter(data)
        lq, hq = next(data_iter)
    lq = lq_tfm(lq).to(device)*2-1 # Rescale from [0, 1] to [-1, 1]
    hq = hq_tfm(hq).to(device)*2-1
    batch_size = lq.shape[0]
    
    
    x0 = hq
    cond_0 = lq
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device) # Random 't's 
    xt, noise = ddg_context.q_xt_x0(x0, t) # Get the noised images (xt) and the noise (our target)
    unet_input = torch.cat((xt, cond_0), dim=1) # Combine with cond
    pred_noise = unet(unet_input, t)[:,:3] # Run xt through the network to get its predictions
    loss = F.mse_loss(noise.float(), pred_noise) # Compare the predictions with the targets
    losses.append(loss.item()) # Store the loss for later viewing
    wandb.log({'Loss':loss.item()}) # Log to wandb
    (loss / grad_accumulation_steps).backward() # Backpropagate, scaled so the accumulated gradient matches one large batch
    
    if (i+1) % grad_accumulation_steps == 0:
        optim.step() # Update the network parameters
        optim.zero_grad() # Zero the gradients
        ema.update() # Update the moving average with the new parameters from the last optimizer step
    
    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({'Examples @120':wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t = 120))})
            wandb.log({'Examples @199':wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t = 199))})
            wandb.log({'Random Examples @120':wandb.Image(eg_im(lq, hq, ddg_context, start_t = 120))})
        
    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        with ema.average_parameters():
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')
        
    if (i+1)%4000 == 0:
        scheduler.step()
        wandb.log({'lr':optim.param_groups[0]['lr']})
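
# After training, a sketch of restoring the EMA weights for sampling
# (the checkpoint name and start_t below are illustrative only):
# unet.load_state_dict(torch.load('ema_unet_100000.pt'))
# unet.eval()
# eg_im(eg_lq, eg_hq, ddg_context, start_t=199).save('samples.png')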