Spaces:
Sleeping
Sleeping
File size: 3,775 Bytes
e321b92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import math

import torch
import torch.nn.functional as F
from torch.autograd import grad
from torch.optim.lr_scheduler import _LRScheduler
class WarmupCosineDecayScheduler(_LRScheduler):
    """Learning-rate schedule: linear warmup followed by cosine decay.

    For steps below ``warmup_steps`` the rate is interpolated linearly from
    ``warmup_start_lr`` to ``max_lr``. From ``warmup_steps`` to
    ``total_steps`` it follows half a cosine period from ``max_lr`` down to
    ``min_lr``. Past ``total_steps`` it stays at ``min_lr``.

    The same rate is applied to every parameter group; the learning rates
    configured on the optimizer itself are ignored.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step index at which the cosine decay reaches ``min_lr``.
        warmup_start_lr: Learning rate at step 0.
        max_lr: Learning rate at the end of warmup.
        min_lr: Learning rate at and after ``total_steps``.
        last_epoch: Index of the last step, forwarded to the base class.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, warmup_start_lr=1e-9,
                 max_lr=1e-4, min_lr=1e-6, last_epoch=-1):
        # These attributes are read by get_lr(), so they are assigned before
        # the base-class constructor runs (it performs an initial step).
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.warmup_start_lr = warmup_start_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, once per param group."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup. warmup_steps > 0 here because step is never
            # negative after the base class's initial step.
            frac = step / self.warmup_steps
            lr = self.max_lr * frac + (1 - frac) * self.warmup_start_lr
        else:
            # Cosine decay. Progress is clamped to 1.0 so the rate stays at
            # min_lr after total_steps instead of following the cosine back up,
            # and a zero-length decay phase does not divide by zero.
            decay_steps = self.total_steps - self.warmup_steps
            if decay_steps > 0:
                progress = min((step - self.warmup_steps) / decay_steps, 1.0)
            else:
                progress = 1.0
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            # Same value as max_lr * ((1 - min_lr/max_lr) * c + min_lr/max_lr),
            # written without dividing by max_lr.
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr for _ in self.base_lrs]
def guidance_grad(pred_shape, net, scale_noise, grad_scale=1, batch_size=32, device="cpu", save_guidance_path=None):
    """Compute a guidance gradient from a denoiser's output on a noised input.

    A noise level ``sigma = 0.01 + U(0, 1) * scale_noise`` is drawn with shape
    ``[batch_size, 1, 1, 1]``. Gaussian noise scaled by ``sigma`` is added to
    ``pred_shape`` and ``net`` is called on the noisy tensor together with
    ``sigma``. The gradient is ``grad_scale * (pred_shape - denoised)``
    averaged over dim 0, with NaN/inf entries replaced by ``nan_to_num``.

    Args:
        pred_shape: Tensor to be guided. It must broadcast against a
            ``[batch_size, 1, 1, 1]`` tensor.
        net: Callable invoked as ``net(noisy, sigma)``.
        scale_noise: Multiplier on the uniform sample that sets ``sigma``.
        grad_scale: Multiplier applied to the residual.
        batch_size: Number of noise levels drawn.
        device: Device on which ``sigma`` is created.
        save_guidance_path: Accepted but not used by this function.

    Returns:
        Tuple ``(grad, denoised)``: the dim-0-averaged gradient and the raw
        output of ``net``.
    """
    sigma_shape = [batch_size, 1, 1, 1]
    sigma = torch.rand(sigma_shape, device=device) * scale_noise + 0.01

    # Only the noising and the denoiser call run without autograd tracking.
    with torch.no_grad():
        perturbation = sigma * torch.randn_like(pred_shape)
        noisy_input = pred_shape + perturbation
        denoised = net(noisy_input, sigma)

    residual = (pred_shape - denoised) * grad_scale
    guidance = torch.nan_to_num(residual.mean(dim=0))
    return guidance, denoised
def guidance_loss(pred_shape, loss_sde, net, grad_scale=1, device="cpu", save_guidance_path=None,
                  scale_noise=1.0, batch_size=32):
    """Turn the guidance gradient into a scalar loss on ``pred_shape``.

    The gradient from ``guidance_grad`` is subtracted from ``pred_shape`` to
    form a detached target. The loss is half the summed squared error between
    ``pred_shape`` and that target, divided by ``pred_shape.shape[0]``.

    Args:
        pred_shape: Tensor being optimised.
        loss_sde: Unused; kept so existing positional callers keep working.
        net: Denoiser forwarded to ``guidance_grad``.
        grad_scale: Forwarded to ``guidance_grad``.
        device: Forwarded to ``guidance_grad``.
        save_guidance_path: Forwarded to ``guidance_grad``.
        scale_noise: Forwarded to ``guidance_grad``.
            NOTE(review): the 1.0 default is a placeholder; the intended
            value is not visible in this file - confirm with callers.
        batch_size: Forwarded to ``guidance_grad``; the default mirrors that
            function's own default.

    Returns:
        Scalar loss tensor.
    """
    # Pass by keyword: guidance_grad's positional order is
    # (pred_shape, net, scale_noise, grad_scale, batch_size, device, ...),
    # which differs from this function's parameter order.
    # It returns (grad, denoised); only the gradient is needed here.
    grad, _ = guidance_grad(
        pred_shape,
        net,
        scale_noise,
        grad_scale=grad_scale,
        batch_size=batch_size,
        device=device,
        save_guidance_path=save_guidance_path,
    )
    # Detach so the target is treated as a constant by autograd.
    targets = (pred_shape - grad).detach()
    loss = 0.5 * F.mse_loss(pred_shape.float(), targets, reduction='sum') / pred_shape.shape[0]
    return loss