# Scraped from a Hugging Face Space page ("Running on Zero"):
# file size 7,486 bytes, commit 0dc3eb6.
# v-diffusion code for DDPM inpainting. May not be compatible with k-diffusion.
# @SuspectT's inpainting code, Feb 25 2024,
# shared with me over Discord:
# "that's the v-diffusion inpainting with ddpm
# optimal settings were around 100 steps for the scheduler
# (ts referring to timesteps here) and resamples was 4"
import torch
from torch import nn
from typing import Callable
from tqdm import trange
import math
import sys
# Adapted from crowsonkb/v-diffusion-pytorch.
def t_to_alpha_sigma(t):
    """Map a timestep tensor to its (alpha, sigma) scaling factors.

    alpha scales the clean signal and sigma scales the noise. They are
    cos/sin of the same angle t * pi / 2, so t = 0 gives (1, 0) and t = 1
    gives (0, 1).
    """
    angle = t * math.pi / 2
    alpha = torch.cos(angle)
    sigma = torch.sin(angle)
    return alpha, sigma
# class DDPM(SamplerBase):  -- the original base class is not available in this file
class DDPM():
    """Sampler and inpainter for v-prediction diffusion models.

    Timestep tensors ``ts`` hold t values that are mapped to (alpha, sigma)
    scaling factors by ``t_to_alpha_sigma``.
    """

    def __init__(self, model_fn: Callable = None):
        # NOTE(review): model_fn is accepted but never stored; every sampling
        # method receives the model explicitly.
        super().__init__()

    def _step(
        self, model_fn: Callable, x_t: torch.Tensor, step: int,
        t_now: torch.Tensor, t_next: torch.Tensor,
        callback: Callable, model_args, **sampler_args
    ) -> torch.Tensor:
        """Move ``x_t`` from timestep ``t_now`` to ``t_next`` without adding noise.

        The model predicts v. From it we recover the noise estimate ``eps_t``
        and the clean-signal estimate ``pred_t``, then recombine them with the
        next timestep's (alpha, sigma).
        """
        alpha_now, sigma_now = t_to_alpha_sigma(t_now)
        alpha_next, sigma_next = t_to_alpha_sigma(t_next)
        # Expand t to one entry per batch element (batch size is x_t.shape[0]).
        v_t = model_fn(x_t, t_now.expand(x_t.shape[0]), **model_args)
        eps_t = x_t * sigma_now + v_t * alpha_now
        pred_t = x_t * alpha_now - v_t * sigma_now
        if callback is not None:
            callback({'step': step, 'x': x_t, 't': t_now, 'pred': pred_t, 'eps': eps_t})
        return pred_t * alpha_next + eps_t * sigma_next

    def _sample(
        self, model_fn: Callable, x_t: torch.Tensor, ts: torch.Tensor,
        callback: Callable, model_args, **sampler_args
    ) -> torch.Tensor:
        """Apply ``_step`` over each consecutive pair of ``ts``, starting at ``x_t``.

        ``callback`` (if given) receives each step's info as keyword arguments,
        plus ``steps=ts.size(0)``. Pass ``use_tqdm=True`` in ``sampler_args``
        for a progress bar.
        """
        print("Using DDPM Sampler.")
        steps = ts.size(0)
        use_range = trange if sampler_args.get('use_tqdm') else range
        step_callback = (
            (lambda info: callback(**dict(info, steps=steps)))
            if callback is not None else None
        )
        for step in use_range(steps - 1):
            x_t = self._step(model_fn, x_t, step, ts[step], ts[step + 1],
                             step_callback, model_args)
        return x_t

    def _inpaint(
        self, model_fn: Callable, audio_source: torch.Tensor, mask: torch.Tensor,
        ts: torch.Tensor, resamples: int, callback: Callable, model_args, **sampler_args
    ) -> torch.Tensor:
        """Regenerate ``audio_source`` where ``mask`` is 0/False and keep it where 1/True.

        Args:
            model_fn: called as ``model_fn(x, sigma, aug_cond=None,
                class_cond=None, mapping_cond=None)``. Its output is used as v.
            audio_source: reference signal. It is rescaled to zero mean and unit
                std (statistics over the whole tensor) before use. The returned
                tensor stays in that rescaled range.
            mask: 1/True where ``audio_source`` is kept, 0/False where content
                is generated. May be bool or float; must broadcast against
                ``audio_source``.
            ts: 1-D timestep tensor, normally decreasing (noisy -> clean).
            resamples: model evaluations per step. All but the last one
                re-noise back to the current noise level.
            callback: optional; called with one info dict per model evaluation.
            model_args: NOTE(review): not forwarded to ``model_fn`` here.
            **sampler_args: ``use_tqdm`` (progress bar), ``debug`` (print
                per-step tensor statistics).

        Returns:
            ``audio_source * mask + x * (1 - mask)`` after the final step.

        Raises:
            RuntimeError: if the model output contains NaNs.
        """
        steps = ts.size(0)
        batch_size = audio_source.size(0)
        alphas, sigmas = t_to_alpha_sigma(ts)
        debug = bool(sampler_args.get('debug'))

        # `1.0 - mask` raises for bool tensors in PyTorch, so invert bool masks with `~`.
        keep = mask
        fill = ~mask if mask.dtype == torch.bool else 1.0 - mask

        # SHH: rescale audio_source to zero mean and unit variance.
        # Clamp the std so a constant (e.g. silent) source does not divide by zero.
        std = audio_source.std().clamp_min(torch.finfo(audio_source.dtype).eps)
        audio_source = (audio_source - audio_source.mean()) / std
        x_t = audio_source

        use_range = trange if sampler_args.get('use_tqdm') else range
        for step in use_range(steps - 1):
            alpha_now, sigma_now = alphas[step], sigmas[step]
            alpha_next, sigma_next = alphas[step + 1], sigmas[step + 1]
            if debug:
                print("step, audio_source.min, audio_source.max, alphas[step], sigmas[step] = ", step, audio_source.min(), audio_source.max(), alpha_now, sigma_now)
            audio_source_noised = audio_source * alpha_now + torch.randn_like(audio_source) * sigma_now
            if debug:
                print("step, audio_source_noised.min, audio_source_noised.max = ", step, audio_source_noised.min(), audio_source_noised.max())
            # Noise std that takes sigma_next back up to sigma_now. The clamp
            # avoids NaN from sqrt of tiny negative differences.
            sigma_dt = (sigma_now ** 2 - sigma_next ** 2).clamp_min(0).sqrt()
            for re in range(resamples):
                # Paste the noised known region over the current estimate.
                x_t = audio_source_noised * keep + x_t * fill
                if debug:
                    print("step, re, x_t.min, x_t.max , sigmas[step]= ", step, re, x_t.min(), x_t.max(), sigma_now)
                # The model follows ImageTransformerDenoiserModelV2.forward(x, sigma, ...),
                # so it receives sigma rather than t.
                # NOTE(review): confirm the model expects this variance-preserving
                # sigma = sin(t*pi/2), not some other noise-level convention.
                v_t = model_fn(x_t, sigma_now.expand(batch_size), aug_cond=None, class_cond=None, mapping_cond=None)
                if debug:
                    print("step, re, v_t.min, v_t.max = ", step, re, v_t.min(), v_t.max())
                if v_t.isnan().any():
                    # Fail loudly instead of sys.exit(0), which killed the host
                    # process and reported success.
                    raise RuntimeError(f"v_t has NaNs (step {step}, resample {re}).")
                eps_t = x_t * sigma_now + v_t * alpha_now
                pred_t = x_t * alpha_now - v_t * sigma_now
                if callback is not None:
                    callback({'steps': steps, 'step': step, 'x': x_t, 't': ts[step], 'pred': pred_t, 'eps': eps_t, 'res': re})
                if re < resamples - 1:
                    # Step down to sigma_next, then add sigma_dt of fresh noise.
                    # Assuming eps_t has unit variance, total noise std returns
                    # to sigma_now for the next resample pass.
                    x_t = pred_t * alpha_now + eps_t * sigma_next + sigma_dt * torch.randn_like(x_t)
                else:
                    x_t = pred_t * alpha_next + eps_t * sigma_next
                if debug:
                    print("step, re, v_t.min, v_t.max, x_t.min, x_t.max = ", step, re, v_t.min(), v_t.max(), x_t.min(), x_t.max())
        return audio_source * keep + x_t * fill
def alpha_sigma_to_t(alpha, sigma):
    """Recover the timestep from (alpha, sigma) scaling factors.

    Takes the angle of the (alpha, sigma) point and rescales it so that a
    quarter turn (pure noise) maps to t = 1.
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
def log_snr_to_alpha_sigma(log_snr):
    """Convert log signal-to-noise ratio into (alpha, sigma) scaling factors.

    alpha**2 = sigmoid(log_snr) and sigma**2 = sigmoid(-log_snr), so
    alpha**2 + sigma**2 == 1 and alpha**2 / sigma**2 == exp(log_snr).
    """
    alpha = torch.sigmoid(log_snr).sqrt()
    sigma = torch.sigmoid(-log_snr).sqrt()
    return alpha, sigma
def get_ddpm_schedule(ddpm_t):
    """Map DDPM-paper timesteps to the timestep convention used in this file.

    The DDPM noise schedule is expressed as a log-SNR, converted to
    (alpha, sigma), and then converted back to a timestep.
    """
    exponent = 1e-4 + 10 * ddpm_t ** 2
    log_snr = -torch.special.expm1(exponent).log()
    return alpha_sigma_to_t(*log_snr_to_alpha_sigma(log_snr))
# class LogSchedule(SchedulerBase):  -- the original base class is not available in this file
class LogSchedule():
    """Timestep schedule that is linear in log-SNR.

    ``create`` returns ``steps`` timesteps. Their log-SNR moves linearly from
    ``min_log_snr`` (at ramp value ``first``, default 1) to ``max_log_snr``
    (at ramp value ``last``, default 0).
    """

    _DEFAULT_MIN_LOG_SNR = -10
    _DEFAULT_MAX_LOG_SNR = 10

    def __init__(self, device: torch.device = None):
        # The parent here is plain `object`, whose __init__ accepts no extra
        # arguments; forwarding `device` raised TypeError. Store it instead,
        # since `create` reads self.device.
        super().__init__()
        self.device = device

    def create(self, steps: int, first: float = 1, last: float = 0,
               device: torch.device = None, scheduler_args: dict = None) -> torch.Tensor:
        """Return a 1-D tensor of ``steps`` timesteps.

        Args:
            steps: number of timesteps.
            first, last: endpoints of the linear ramp fed to ``get_log_schedule``.
            device: device for the result; defaults to the constructor's device.
            scheduler_args: optional dict with ``min_log_snr`` / ``max_log_snr``.
                Missing or None entries fall back to -10 / 10.
        """
        if scheduler_args is None:
            scheduler_args = {}
        ramp = torch.linspace(first, last, steps,
                              device=device if device is not None else self.device)
        min_log_snr = scheduler_args.get('min_log_snr')
        max_log_snr = scheduler_args.get('max_log_snr')
        return self.get_log_schedule(
            ramp,
            min_log_snr if min_log_snr is not None else self._DEFAULT_MIN_LOG_SNR,
            max_log_snr if max_log_snr is not None else self._DEFAULT_MAX_LOG_SNR,
        )

    def get_log_schedule(self, t, min_log_snr=-10, max_log_snr=10):
        """Map ramp values ``t`` to timesteps via a log-SNR linear in ``t``.

        t = 0 gives log-SNR ``max_log_snr`` and t = 1 gives ``min_log_snr``.
        """
        log_snr = t * (min_log_snr - max_log_snr) + max_log_snr
        alpha = log_snr.sigmoid().sqrt()
        sigma = log_snr.neg().sigmoid().sqrt()
        # Yes, this is a timestep: alpha**2 + sigma**2 == 1, so this angle
        # (scaled to [0, 1]) inverts t -> (cos(t*pi/2), sin(t*pi/2)).
        return torch.atan2(sigma, alpha) / math.pi * 2
# class CrashSchedule(SchedulerBase):  -- the original base class is not available in this file
class CrashSchedule():
    """Timestep schedule with sigma = sin(ramp * pi / 2) ** 2 and alpha = sqrt(1 - sigma**2)."""

    def __init__(self, device: torch.device = None):
        # The parent here is plain `object`, whose __init__ accepts no extra
        # arguments; forwarding `device` raised TypeError. Store it instead,
        # since `create` reads self.device.
        super().__init__()
        self.device = device

    def create(self, steps: int, first: float = 1, last: float = 0,
               device: torch.device = None, scheduler_args=None) -> torch.Tensor:
        """Return a 1-D tensor of ``steps`` timesteps.

        ``first``/``last`` are the endpoints of the linear ramp. ``device``
        defaults to the constructor's device. ``scheduler_args`` is accepted
        for interface compatibility and is unused.
        """
        ramp = torch.linspace(first, last, steps,
                              device=device if device is not None else self.device)
        sigma = torch.sin(ramp * math.pi / 2) ** 2
        alpha = (1 - sigma ** 2) ** 0.5
        # Convert (alpha, sigma) back to a timestep in [0, 1].
        return torch.atan2(sigma, alpha) / math.pi * 2