import json
from abc import ABCMeta, abstractmethod
from math import sqrt
from pathlib import Path

import numpy as np
import torch

class ScoreAdapter(metaclass=ABCMeta):

    @abstractmethod
    def denoise(self, xs, σ, **kwargs):
        pass

    def score(self, xs, σ, **kwargs):
        # Tweedie's formula: ∇x log p_σ(x) = (D(x; σ) - x) / σ²
        Ds = self.denoise(xs, σ, **kwargs)
        grad_log_p_t = (Ds - xs) / (σ ** 2)
        return grad_log_p_t

    @abstractmethod
    def data_shape(self):
        return (3, 256, 256)  # for example

    def samps_centered(self):
        # if centered, samples are expected to be in range [-1, 1], else [0, 1]
        return True

    @property
    @abstractmethod
    def σ_max(self):
        pass

    @property
    @abstractmethod
    def σ_min(self):
        pass

    def cond_info(self, batch_size):
        return {}

    @abstractmethod
    def unet_is_cond(self):
        return False

    @abstractmethod
    def use_cls_guidance(self):
        return False  # most models do not use classifier guidance

    def classifier_grad(self, xs, σ, ys):
        raise NotImplementedError()

    @abstractmethod
    def snap_t_to_nearest_tick(self, t):
        # need to confirm for each model; a continuous-time model doesn't need this
        return t, None

    @property
    def device(self):
        return self._device

    def checkpoint_root(self):
        """the path at which the pretrained checkpoints are stored"""
        with Path(__file__).resolve().with_name("env.json").open("r") as f:
            root = json.load(f)
        return root
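
# --- Illustrative example, not part of the original interface. ---
# A minimal concrete ScoreAdapter sketch: for zero-mean, unit-variance Gaussian
# data the optimal denoiser has the closed form D(x; σ) = x / (1 + σ²), so
# score() recovers ∇x log p_σ(x) = -x / (1 + σ²) exactly. The class name and
# default values here are hypothetical.
class GaussianToyAdapter(ScoreAdapter):
    def __init__(self, shape=(3, 8, 8), device="cpu"):
        self._shape = shape
        self._device = torch.device(device)

    def denoise(self, xs, σ, **kwargs):
        # posterior mean of the clean data given xs = x0 + σ·n, x0 ~ N(0, I)
        return xs / (1 + σ ** 2)

    def data_shape(self):
        return self._shape

    @property
    def σ_max(self):
        return 80.0

    @property
    def σ_min(self):
        return 0.002

    def unet_is_cond(self):
        return False

    def use_cls_guidance(self):
        return False

    def snap_t_to_nearest_tick(self, t):
        return t, None  # continuous-time toy; no discrete ticks to snap to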

def karras_t_schedule(ρ=7, N=10, σ_max=80, σ_min=0.002):
    # time discretization of Karras et al. 2022:
    # t_i = (σ_max^(1/ρ) + (i / (N-1)) · (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ
    ts = []
    for i in range(N):
        t = (
            σ_max ** (1 / ρ) + (i / (N - 1)) * (σ_min ** (1 / ρ) - σ_max ** (1 / ρ))
        ) ** ρ
        ts.append(t)
    return ts
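
# Illustrative sanity check (hypothetical helper, not in the original file):
# the interpolation runs in σ^(1/ρ) space, so the endpoints land exactly on
# σ_max (at i = 0) and σ_min (at i = N - 1), decreasing monotonically between.
def _check_karras_endpoints():
    ts = karras_t_schedule(ρ=7, N=10, σ_max=80, σ_min=0.002)
    assert abs(ts[0] - 80) < 1e-6
    assert abs(ts[-1] - 0.002) < 1e-9
    assert all(a > b for a, b in zip(ts, ts[1:]))  # strictly decreasing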

def power_schedule(σ_max, σ_min, num_stages):
    # geometric (log-linear) interpolation from σ_max down to σ_min
    σs = np.exp(np.linspace(np.log(σ_max), np.log(σ_min), num_stages))
    return σs
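
# Illustrative values (not in the original file): since the interpolation is
# log-linear, the stages form a geometric progression between the endpoints,
# e.g. power_schedule(80, 0.002, 4) -> approximately [80.0, 2.34, 0.068, 0.002].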

class Karras:
    """Sampler in the style of Karras et al. 2022: Heun's 2nd order method
    with optional stochastic churn (noise re-injection between steps)."""

    @classmethod
    @torch.no_grad()
    def inference(
        cls, model, batch_size, num_t, *,
        σ_max=80, cls_scaling=1,
        init_xs=None, heun=True,
        langevin=False,
        S_churn=80, S_min=0.05, S_max=50, S_noise=1.003,
    ):
        σ_max = min(σ_max, model.σ_max)  # stay within the model's own range
        σ_min = model.σ_min
        ts = karras_t_schedule(ρ=7, N=num_t, σ_max=σ_max, σ_min=σ_min)
        assert len(ts) == num_t
        ts = [model.snap_t_to_nearest_tick(t)[0] for t in ts]
        ts.append(0)  # 0 is the destination
        σ_max = ts[0]

        cond_inputs = model.cond_info(batch_size)
        def compute_step(xs, σ):
            grad_log_p_t = model.score(
                xs, σ, **(cond_inputs if model.unet_is_cond() else {})
            )
            if model.use_cls_guidance():
                grad_cls = model.classifier_grad(xs, σ, cond_inputs["y"])
                grad_cls = grad_cls * cls_scaling
                grad_log_p_t += grad_cls
            # probability-flow ODE derivative: dx/dt = -t · ∇x log p_t(x)
            d_i = -1 * σ * grad_log_p_t
            return d_i
        if init_xs is not None:
            xs = init_xs.to(model.device)
        else:
            # start from pure noise at the largest noise level
            xs = σ_max * torch.randn(
                batch_size, *model.data_shape(), device=model.device
            )

        yield xs
        for i in range(num_t):
            t_i = ts[i]

            if langevin and (S_min < t_i < S_max):
                # stochastic churn: raise the noise level slightly before stepping
                xs, t_i = cls.noise_backward_in_time(
                    model, xs, t_i, S_noise, S_churn / num_t
                )

            Δt = ts[i + 1] - t_i

            d_1 = compute_step(xs, σ=t_i)
            xs_1 = xs + Δt * d_1

            # Heun's 2nd order method; don't apply on the last step
            if (not heun) or (ts[i + 1] == 0):
                xs = xs_1
            else:
                d_2 = compute_step(xs_1, σ=ts[i + 1])
                xs = xs + Δt * (d_1 + d_2) / 2

            yield xs
    @staticmethod
    def noise_backward_in_time(model, xs, t_i, S_noise, S_churn_i):
        n = S_noise * torch.randn_like(xs)
        γ_i = min(sqrt(2) - 1, S_churn_i)
        t_i_hat = t_i * (1 + γ_i)
        t_i_hat = model.snap_t_to_nearest_tick(t_i_hat)[0]
        # inject just enough noise to lift xs from level t_i to level t_i_hat
        xs = xs + n * sqrt(t_i_hat ** 2 - t_i ** 2)
        return xs, t_i_hat
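
# Usage sketch (hypothetical, relies on the GaussianToyAdapter example above):
# inference() is a generator that yields the initial noise followed by one
# iterate per step, so the final sample is the last value it yields.
def _demo_sampling():
    model = GaussianToyAdapter()
    *_, x_final = Karras.inference(model, batch_size=4, num_t=10)
    print(x_final.shape)  # torch.Size([4, 3, 8, 8])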

def test():
    pass


if __name__ == "__main__":
    test()