# Source: ReNoise-Inversion / src/inversion_utils.py
# Page-header residue from the hosting site: "garibida's picture",
# "Upload Files", commit d65c9b3.
import torch
from random import randrange
import torch.nn.functional as F
def noise_regularization(
    e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls
):
    """Refine a noise tensor with gradient steps on two regularization losses.

    Each of the ``num_reg_steps`` outer iterations performs, in this order:

    * one gradient step on
      ``patchify_latents_kl_divergence(e_t, noise_pred_optimal)`` scaled by
      ``lambda_kl`` (skipped unless ``lambda_kl > 0``);
    * ``num_ac_rolls`` gradient steps on ``auto_corr_loss(e_t)``, each scaled
      by ``lambda_ac / num_ac_rolls`` (skipped unless ``lambda_ac > 0``).

    Args:
        e_t: Tensor to regularize. It is never modified in place; every step
            builds a new tensor.
        noise_pred_optimal: Reference tensor handed to the KL loss as its
            second argument.
        lambda_kl: Step size of the KL update; values ``<= 0`` disable it.
        lambda_ac: Step size of the auto-correlation update; values ``<= 0``
            disable it.
        num_reg_steps: Number of outer iterations.
        num_ac_rolls: Number of auto-correlation steps per outer iteration.

    Returns:
        The updated tensor, detached from the autograd graph. When
        ``num_reg_steps`` is 0 the input object itself is returned.
    """
    # Both updates call ``.backward()``; enabling grad mode locally keeps the
    # function usable when the caller runs under ``torch.no_grad()``.
    with torch.enable_grad():
        for _ in range(num_reg_steps):
            if lambda_kl > 0:
                # Fresh leaf copy so the gradient is taken w.r.t. the current
                # estimate only (replaces the deprecated autograd.Variable).
                leaf = e_t.detach().clone().requires_grad_(True)
                kl_loss = patchify_latents_kl_divergence(leaf, noise_pred_optimal)
                kl_loss.backward()
                # Gradient entries are clamped to [-100, 100] before the step.
                kl_grad = torch.clip(leaf.grad.detach(), -100, 100)
                e_t = e_t - lambda_kl * kl_grad
            if lambda_ac > 0:
                for _ in range(num_ac_rolls):
                    leaf = e_t.detach().clone().requires_grad_(True)
                    ac_loss = auto_corr_loss(leaf)
                    ac_loss.backward()
                    # Dividing by the roll count averages the inner steps.
                    ac_grad = leaf.grad.detach() / num_ac_rolls
                    e_t = e_t - lambda_ac * ac_grad
            e_t = e_t.detach()
    return e_t
def auto_corr_loss(x, random_shift=True):
    """Return the summed squared auto-correlation of every channel of ``x``.

    For each channel, the plane is multiplied by a copy of itself rolled along
    the height axis and along the width axis; the squared mean of each product
    is added to the result. The plane is then halved with 2x2 average pooling
    and the measurement repeated, until a level whose height is at most 8 has
    been processed.

    Args:
        x: Tensor of shape ``[1, C, H, W]``; a batch size other than 1 trips
            the assertion below.
        random_shift: When true, the roll distance of each level is drawn
            with ``randrange(height // 2)`` (one draw shared by both axes);
            otherwise the distance is always 1.

    Returns:
        The accumulated loss: a scalar tensor, or the float ``0.0`` when
        ``x`` has no channels.
    """
    batch, _channels, _height, _width = x.shape
    assert batch == 1
    planes = x.squeeze(0)  # [C, H, W]
    total = 0.0
    for index in range(planes.shape[0]):
        # Restore batch and channel axes so avg_pool2d accepts the plane.
        level = planes[index].unsqueeze(0).unsqueeze(0)
        while True:
            shift = randrange(level.shape[2] // 2) if random_shift else 1
            for axis in (2, 3):
                shifted = torch.roll(level, shifts=shift, dims=axis)
                total += (level * shifted).mean() ** 2
            if level.shape[2] <= 8:
                break
            level = F.avg_pool2d(level, kernel_size=2)
    return total
def patchify_latents_kl_divergence(x0, x1, patch_size=4, num_channels=4):
    """Sum ``latents_kl_divergence`` over non-overlapping patches.

    Both inputs are cut into patches of shape
    ``(num_channels, patch_size, patch_size)``; the per-patch divergences
    returned by ``latents_kl_divergence`` are then summed into one scalar.

    Args:
        x0: First tensor; the unfold calls address dims 1, 2 and 3, so a
            ``[B, C, H, W]`` layout is assumed.
        x1: Second tensor, patchified the same way as ``x0``.
        patch_size: Side length of each spatial patch.
        num_channels: Number of channels grouped into each patch.

    Returns:
        Scalar tensor holding the summed divergence.
    """

    def patchify_tensor(input_tensor):
        # ``unfold`` appends each window as a new trailing dimension, so the
        # last three dims end up as (num_channels, patch_size, patch_size).
        # The channel axis must be windowed by ``num_channels`` (not
        # ``patch_size``); otherwise the view below would stitch together
        # data from neighbouring patches whenever the two values differ.
        # Elements that do not fill a complete window are dropped by unfold.
        patches = (
            input_tensor.unfold(1, num_channels, num_channels)
            .unfold(2, patch_size, patch_size)
            .unfold(3, patch_size, patch_size)
        )
        return patches.contiguous().view(-1, num_channels, patch_size, patch_size)

    patches0 = patchify_tensor(x0)
    patches1 = patchify_tensor(x1)
    return latents_kl_divergence(patches0, patches1).sum()
def latents_kl_divergence(x0, x1):
    """Return a Gaussian KL-style divergence between ``x0`` and ``x1``.

    Every ``[i, j]`` slice of each input is flattened, and its mean and
    variance (``Tensor.var`` with default settings) are taken over the
    flattened values. For each slice the term

        log((var1 + eps) / (var0 + eps))
        + (var0 + (mu0 - mu1) ** 2) / (var1 + eps) - 1

    is computed, i.e. twice the closed-form KL divergence between two
    univariate Gaussians, up to the ``eps`` stabilizer. Absolute values of
    these terms are summed over dim 1.

    Args:
        x0: Tensor with at least two dimensions.
        x1: Tensor whose first two dimensions match (or broadcast against)
            those of ``x0``.

    Returns:
        Tensor of shape ``[x0.shape[0]]`` with one value per leading index.
    """
    EPSILON = 1e-6  # keeps the log and the division finite for zero variance
    # ``reshape`` instead of ``view`` so non-contiguous inputs (for example
    # permuted tensors) are accepted; contiguous inputs behave as before.
    x0 = x0.reshape(x0.shape[0], x0.shape[1], -1)
    x1 = x1.reshape(x1.shape[0], x1.shape[1], -1)
    mu0 = x0.mean(dim=-1)
    mu1 = x1.mean(dim=-1)
    var0 = x0.var(dim=-1)
    var1 = x1.var(dim=-1)
    kl = (
        torch.log((var1 + EPSILON) / (var0 + EPSILON))
        + (var0 + (mu0 - mu1) ** 2) / (var1 + EPSILON)
        - 1
    )
    # The magnitude is taken before summing over dim 1.
    return torch.abs(kl).sum(dim=-1)