# STAR: video_to_video/diffusion/solvers_sdedit.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchsde
from tqdm.auto import trange

from video_to_video.utils.logger import get_logger

logger = get_logger()


def get_ancestral_step(sigma_from, sigma_to, eta=1.):
"""
Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step.
"""
if not eta:
return sigma_to, 0.
    sigma_up = min(
        sigma_to,
        eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2)
               / sigma_from**2)**0.5)
sigma_down = (sigma_to**2 - sigma_up**2)**0.5
return sigma_down, sigma_up
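

# Illustrative example: stepping from sigma_from = 2.0 to sigma_to = 1.0
# with eta=1 splits the step into a deterministic part (sigma_down) and
# fresh noise (sigma_up), with sigma_down**2 + sigma_up**2 == sigma_to**2:
#
#   >>> get_ancestral_step(2.0, 1.0, eta=1.)  # -> roughly (0.5, 0.866)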


def get_scalings(sigma):
    """
    Input/output scalings for the denoiser at noise level sigma. The
    samplers in this file only use c_in = 1 / sqrt(sigma**2 + 1), which
    scales the noisy input to roughly unit variance before the model call.
    """
    c_out = -sigma
    c_in = 1 / (sigma**2 + 1.**2)**0.5
    return c_out, c_in
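

# Illustrative example: at sigma = 3 the noisy input is shrunk by
# c_in = 1 / sqrt(3**2 + 1) before being fed to the network:
#
#   >>> _, c_in = get_scalings(torch.tensor(3.))
#   >>> c_in  # ~= 0.316

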
@torch.no_grad()
def sample_heun(noise,
model,
sigmas,
s_churn=0.,
s_tmin=0.,
s_tmax=float('inf'),
s_noise=1.,
show_progress=True):
"""
Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
"""
x = noise * sigmas[0]
for i in trange(len(sigmas) - 1, disable=not show_progress):
gamma = 0.
if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
        if sigmas[i] == float('inf'):
            # First step for schedules that start at sigma = inf: feed pure
            # noise to the model and re-noise its prediction to the next
            # sigma level (gamma is 0 here, so sigma_hat is also inf).
            denoised = model(noise, sigma_hat)
            x = denoised + sigmas[i + 1] * (gamma + 1) * noise
else:
_, c_in = get_scalings(sigma_hat)
denoised = model(x * c_in, sigma_hat)
d = (x - denoised) / sigma_hat
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
_, c_in = get_scalings(sigmas[i + 1])
denoised_2 = model(x_2 * c_in, sigmas[i + 1])
d_2 = (x_2 - denoised_2) / sigmas[i + 1]
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x
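

# Usage sketch (hypothetical names): `denoiser` is assumed to take a
# c_in-scaled latent and a sigma value and return a denoised estimate,
# matching the `model` argument above; `sigmas` is a decreasing noise
# schedule ending at 0:
#
#   sigmas = torch.tensor([14.6, 6.3, 2.7, 1.1, 0.4, 0.1, 0.0])
#   noise = torch.randn(1, 4, 8, 32, 32)  # (B, C, T, H, W) video latents
#   sample = sample_heun(noise, denoiser, sigmas, show_progress=False)

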
class BatchedBrownianTree:
"""
A wrapper around torchsde.BrownianTree that enables batches of entropy.
"""
def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2**63 - 1, []).item()
self.batched = True
        try:
            # A sequence of seeds: one BrownianTree per batch item.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            # A single int seed (len() raises TypeError): use one tree for
            # the whole batch.
            seed = [seed]
            self.batched = False
self.trees = [
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
for s in seed
]

    @staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
self.sign * sign)
return w if self.batched else w[0]
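

# Illustrative example: a list of seeds gives each batch item its own
# reproducible Brownian path:
#
#   x = torch.zeros(2, 3)
#   tree = BatchedBrownianTree(x, torch.tensor(0.), torch.tensor(1.),
#                              seed=[0, 1])
#   w = tree(torch.tensor(0.), torch.tensor(0.5))  # shape (2, 3)

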
class BrownianTreeNoiseSampler:
"""
A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will
use one BrownianTree per batch item, each with its own seed.
transform (callable): A function that maps sigma to the sampler's
internal timestep.
"""

    def __init__(self,
x,
sigma_min,
sigma_max,
seed=None,
transform=lambda x: x):
self.transform = transform
t0 = self.transform(torch.as_tensor(sigma_min))
t1 = self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
t0 = self.transform(torch.as_tensor(sigma))
t1 = self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
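

# Illustrative example: the sampler returns noise whose scale is
# independent of the queried interval (the raw Brownian increment is
# divided by sqrt(|t1 - t0|)):
#
#   x = torch.randn(1, 4, 8, 32, 32)
#   sampler = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6)
#   eps = sampler(14.6, 6.3)  # same shape as x, roughly unit variance

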
@torch.no_grad()
def sample_dpmpp_2m_sde(noise,
model,
sigmas,
eta=1.,
s_noise=1.,
solver_type='midpoint',
show_progress=True,
variant_info=None):
"""
DPM-Solver++ (2M) SDE.
"""
assert solver_type in {'heun', 'midpoint'}
x = noise * sigmas[0]
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
sigmas < float('inf')].max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
old_denoised = None
h_last = None
for i in trange(len(sigmas) - 1, disable=not show_progress):
logger.info(f'step: {i}')
        if sigmas[i] == float('inf'):
            # First step for schedules that start at sigma = inf: feed pure
            # noise to the model and re-noise its prediction to the next
            # sigma level.
            denoised = model(noise, sigmas[i], variant_info=variant_info)
            x = denoised + sigmas[i + 1] * noise
else:
_, c_in = get_scalings(sigmas[i])
denoised = model(x * c_in, sigmas[i], variant_info=variant_info)
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
                # DPM-Solver++(2M) SDE update in log-sigma time
                # (t = -log(sigma)); h is the step size and eta controls
                # how much extra noise is injected.
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
(-h - eta_h).expm1().neg() * denoised
                if old_denoised is not None:
                    # Second-order multistep correction using the previous
                    # denoised estimate; r is the ratio of step sizes.
                    r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
(1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * \
(1 / r) * (denoised - old_denoised)
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) \
                    * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() \
                    * s_noise
old_denoised = denoised
h_last = h

    if variant_info is not None and variant_info.get('type') == 'variant1':
        # Blend the 'long' and 'short' halves of the batch with weight alpha.
        x_long, x_short = x.chunk(2, dim=0)
        x = x_long * (1 - variant_info['alpha']) \
            + x_short * variant_info['alpha']
return x
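

# Usage sketch (hypothetical names): `denoiser` is assumed to accept the
# extra `variant_info` keyword, matching the `model` argument above:
#
#   sigmas = torch.tensor([14.6, 6.3, 2.7, 1.1, 0.4, 0.1, 0.0])
#   noise = torch.randn(1, 4, 8, 32, 32)
#   sample = sample_dpmpp_2m_sde(noise, denoiser, sigmas, eta=1.,
#                                solver_type='midpoint')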