File size: 3,775 Bytes
e321b92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import math

import torch
import torch.nn.functional as F
from torch.autograd import grad
from torch.optim.lr_scheduler import _LRScheduler


class WarmupCosineDecayScheduler(_LRScheduler):
    """Linear warmup followed by cosine decay of the learning rate.

    For ``last_epoch < warmup_steps`` the learning rate is interpolated
    linearly from ``warmup_start_lr`` to ``max_lr``. Afterwards it follows half
    a cosine period from ``max_lr`` down to ``min_lr``, reaching ``min_lr`` at
    ``total_steps`` and staying there for any later step.

    The same learning rate is applied to every parameter group; the
    optimizer's own base learning rates are only used to determine how many
    groups there are.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in the linear warmup.
        total_steps: Step index at which the cosine decay reaches ``min_lr``.
        warmup_start_lr: Learning rate at step 0 of the warmup.
        max_lr: Learning rate at the end of the warmup / start of the decay.
        min_lr: Learning rate at and after ``total_steps``.
        last_epoch: Index of the last step, forwarded to the base class.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, warmup_start_lr=1e-9, max_lr=1e-4, min_lr=1e-6, last_epoch=-1):
        # The schedule attributes must be set before the base-class
        # constructor runs, because it performs an initial step that already
        # calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.warmup_start_lr = warmup_start_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, once per param group."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup (warmup_steps > 0 is guaranteed on this branch
            # because step is never negative when get_lr is called).
            frac = step / self.warmup_steps
            lr = self.max_lr * frac + (1 - frac) * self.warmup_start_lr
        else:
            # Cosine decay. Progress is clamped to 1 so the rate holds at
            # min_lr after total_steps instead of rising again as the cosine
            # turns around; an empty decay span is treated as already finished.
            decay_span = self.total_steps - self.warmup_steps
            if decay_span > 0:
                progress = min(1.0, (step - self.warmup_steps) / decay_span)
            else:
                progress = 1.0
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            # Interpolating directly avoids dividing by max_lr.
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr for _ in self.base_lrs]


def guidance_grad(pred_shape, net, scale_noise, grad_scale=1, batch_size=32, device="cpu", save_guidance_path=None):
    """Compute a guidance direction from a denoiser applied to a noised input.

    Noise levels are sampled as ``0.01 + U[0, 1) * scale_noise`` with shape
    ``[batch_size, 1, 1, 1]``. Gaussian noise scaled by those levels is added
    to ``pred_shape`` and ``net`` is called on the result (without gradient
    tracking). The residual ``grad_scale * (pred_shape - denoised)`` is then
    averaged over dim 0 and passed through ``torch.nan_to_num``.

    Args:
        pred_shape: Tensor to perturb; it must broadcast against a
            ``[batch_size, 1, 1, 1]`` tensor.
        net: Callable invoked as ``net(noisy_input, sigma)``.
        scale_noise: Multiplier on the uniformly sampled part of the noise level.
        grad_scale: Factor applied to the residual before averaging.
        batch_size: Number of noise levels to sample.
        device: Device on which the noise levels are created.
        save_guidance_path: Currently unused; no visualisation is written.

    Returns:
        A tuple ``(guidance, denoised)``: the dim-0 mean of the scaled
        residual, and the raw output of ``net``.
    """
    # One noise level per batch entry, broadcastable over the remaining dims.
    noise_levels = 0.01 + torch.rand([batch_size, 1, 1, 1], device=device) * scale_noise

    # The denoiser is evaluated without building an autograd graph.
    with torch.no_grad():
        perturbation = torch.randn_like(pred_shape) * noise_levels
        noisy_input = pred_shape + perturbation
        denoised = net(noisy_input, noise_levels)

    # The residual is left unweighted by the noise level (no 1 / sigma**2).
    residual = grad_scale * (pred_shape - denoised)
    guidance = torch.nan_to_num(torch.mean(residual, dim=0))

    return guidance, denoised

def guidance_loss(pred_shape, loss_sde, net, grad_scale=1, device="cpu", save_guidance_path=None, scale_noise=1.0, batch_size=None):
    """Turn the guidance direction from ``guidance_grad`` into a scalar loss.

    The loss is ``0.5 * sum((pred_shape - targets) ** 2) / pred_shape.shape[0]``
    with ``targets = (pred_shape - guidance).detach()``, so its gradient with
    respect to ``pred_shape`` is the guidance direction divided by the leading
    dimension of ``pred_shape``.

    Args:
        pred_shape: Tensor being optimised.
        loss_sde: Not used by this function; kept so existing call sites keep
            working. NOTE(review): the previous body forwarded it to
            ``guidance_grad`` in the ``net`` slot, which looks like a leftover
            from an older signature -- confirm it can be dropped.
        net: Denoiser forwarded to ``guidance_grad``.
        grad_scale: Forwarded to ``guidance_grad``.
        device: Forwarded to ``guidance_grad``.
        save_guidance_path: Forwarded to ``guidance_grad``.
        scale_noise: Forwarded to ``guidance_grad``. NOTE(review): the default
            of 1.0 is a placeholder, callers should pass the value they use
            elsewhere.
        batch_size: Number of noise levels to sample; defaults to
            ``pred_shape.shape[0]`` so the sampled levels always broadcast
            against ``pred_shape``.

    Returns:
        Scalar loss tensor.
    """
    if batch_size is None:
        batch_size = pred_shape.shape[0]

    # Keyword arguments keep this call aligned with guidance_grad's signature;
    # it returns (guidance, denoised) and only the guidance term is needed.
    guidance, _ = guidance_grad(
        pred_shape,
        net,
        scale_noise,
        grad_scale=grad_scale,
        batch_size=batch_size,
        device=device,
        save_guidance_path=save_guidance_path,
    )

    # Detached targets make the MSE gradient equal to the guidance direction.
    targets = (pred_shape - guidance).detach().float()
    loss = 0.5 * F.mse_loss(pred_shape.float(), targets, reduction='sum') / pred_shape.shape[0]
    return loss