Spaces:
Sleeping
Sleeping
drscotthawley
commited on
Commit
•
b887586
1
Parent(s):
b46aa4b
needed to add sample.py
Browse files
sample.py
ADDED
@@ -0,0 +1,645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH)
|
4 |
+
|
5 |
+
"""Samples from k-diffusion models."""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import accelerate
|
11 |
+
import safetensors.torch as safetorch
|
12 |
+
import torch
|
13 |
+
from tqdm import trange, tqdm
|
14 |
+
from PIL import Image
|
15 |
+
from torchvision import transforms
|
16 |
+
|
17 |
+
import k_diffusion as K
|
18 |
+
|
19 |
+
from control_toys.v_diffusion import DDPM, LogSchedule, CrashSchedule
|
20 |
+
#CHORD_BORDER = 8 # chord border size in pixels
|
21 |
+
from control_toys.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
|
22 |
+
|
23 |
+
|
24 |
+
# ---- my mangled sampler that includes repaint
|
25 |
+
import torchsde
|
26 |
+
|
27 |
+
class BatchedBrownianTree:
|
28 |
+
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
29 |
+
|
30 |
+
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
31 |
+
t0, t1, self.sign = self.sort(t0, t1)
|
32 |
+
w0 = kwargs.get('w0', torch.zeros_like(x))
|
33 |
+
if seed is None:
|
34 |
+
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
35 |
+
self.batched = True
|
36 |
+
try:
|
37 |
+
assert len(seed) == x.shape[0]
|
38 |
+
w0 = w0[0]
|
39 |
+
except TypeError:
|
40 |
+
seed = [seed]
|
41 |
+
self.batched = False
|
42 |
+
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
43 |
+
|
44 |
+
@staticmethod
|
45 |
+
def sort(a, b):
|
46 |
+
return (a, b, 1) if a < b else (b, a, -1)
|
47 |
+
|
48 |
+
def __call__(self, t0, t1):
|
49 |
+
t0, t1, sign = self.sort(t0, t1)
|
50 |
+
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
51 |
+
return w if self.batched else w[0]
|
52 |
+
|
53 |
+
|
54 |
+
class BrownianTreeNoiseSampler:
|
55 |
+
"""A noise sampler backed by a torchsde.BrownianTree.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
59 |
+
random samples.
|
60 |
+
sigma_min (float): The low end of the valid interval.
|
61 |
+
sigma_max (float): The high end of the valid interval.
|
62 |
+
seed (int or List[int]): The random seed. If a list of seeds is
|
63 |
+
supplied instead of a single integer, then the noise sampler will
|
64 |
+
use one BrownianTree per batch item, each with its own seed.
|
65 |
+
transform (callable): A function that maps sigma to the sampler's
|
66 |
+
internal timestep.
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
|
70 |
+
self.transform = transform
|
71 |
+
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
|
72 |
+
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
73 |
+
|
74 |
+
def __call__(self, sigma, sigma_next):
|
75 |
+
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
|
76 |
+
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
77 |
+
|
78 |
+
def append_dims(x, target_dims):
|
79 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
80 |
+
dims_to_append = target_dims - x.ndim
|
81 |
+
if dims_to_append < 0:
|
82 |
+
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
83 |
+
return x[(...,) + (None,) * dims_to_append]
|
84 |
+
|
85 |
+
|
86 |
+
def to_d(x, sigma, denoised):
|
87 |
+
"""Converts a denoiser output to a Karras ODE derivative."""
|
88 |
+
return (x - denoised) / append_dims(sigma, x.ndim)
|
89 |
+
|
90 |
+
|
91 |
+
@torch.no_grad()
|
92 |
+
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
|
93 |
+
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
94 |
+
extra_args = {} if extra_args is None else extra_args
|
95 |
+
s_in = x.new_ones([x.shape[0]])
|
96 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
97 |
+
for u in range(repaint):
|
98 |
+
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
99 |
+
eps = torch.randn_like(x) * s_noise
|
100 |
+
sigma_hat = sigmas[i] * (gamma + 1)
|
101 |
+
if gamma > 0:
|
102 |
+
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
103 |
+
denoised = model(x, sigma_hat * s_in, **extra_args)
|
104 |
+
d = to_d(x, sigma_hat, denoised)
|
105 |
+
if callback is not None:
|
106 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
107 |
+
dt = sigmas[i + 1] - sigma_hat
|
108 |
+
# Euler method
|
109 |
+
x = x + d * dt
|
110 |
+
if x.isnan().any():
|
111 |
+
assert False, f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}"
|
112 |
+
if u < repaint - 1:
|
113 |
+
beta = (sigmas[i + 1] / sigmas[-1]) ** 2
|
114 |
+
x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn_like(x)
|
115 |
+
|
116 |
+
return x
|
117 |
+
|
118 |
+
def get_scalings(sigma, sigma_data=0.5):
|
119 |
+
c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
|
120 |
+
c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2) ** 0.5
|
121 |
+
c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
|
122 |
+
return c_skip, c_out, c_in
|
123 |
+
|
124 |
+
|
125 |
+
@torch.no_grad()
|
126 |
+
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
|
127 |
+
disable=None, eta=1., s_noise=1., noise_sampler=None,
|
128 |
+
solver_type='midpoint',
|
129 |
+
repaint=4):
|
130 |
+
"""DPM-Solver++(2M) SDE. but with repaint added"""
|
131 |
+
|
132 |
+
if solver_type not in {'heun', 'midpoint'}:
|
133 |
+
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
134 |
+
|
135 |
+
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
136 |
+
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
137 |
+
extra_args = {} if extra_args is None else extra_args
|
138 |
+
s_in = x.new_ones([x.shape[0]])
|
139 |
+
|
140 |
+
old_denoised = None
|
141 |
+
h_last = None
|
142 |
+
old_x = None
|
143 |
+
|
144 |
+
for i in trange(len(sigmas) - 1, disable=disable): # time loop
|
145 |
+
|
146 |
+
for u in range(repaint):
|
147 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
148 |
+
if callback is not None:
|
149 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
150 |
+
#print("i, u, sigmas[i], sigmas[i + 1] = ", i, u, sigmas[i], sigmas[i + 1])
|
151 |
+
if sigmas[i + 1] == 0:
|
152 |
+
# Denoising step
|
153 |
+
x = denoised
|
154 |
+
else:
|
155 |
+
# DPM-Solver++(2M) SDE
|
156 |
+
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
157 |
+
h = s - t
|
158 |
+
eta_h = eta * h
|
159 |
+
|
160 |
+
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
161 |
+
|
162 |
+
if old_denoised is not None:
|
163 |
+
r = h_last / h
|
164 |
+
if solver_type == 'heun':
|
165 |
+
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
166 |
+
elif solver_type == 'midpoint':
|
167 |
+
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
168 |
+
|
169 |
+
if eta:
|
170 |
+
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
171 |
+
|
172 |
+
|
173 |
+
if callback is not None:
|
174 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
175 |
+
|
176 |
+
if x.isnan().any():
|
177 |
+
assert False, f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}"
|
178 |
+
|
179 |
+
if u < repaint - 1:
|
180 |
+
# RePaint: go "back" in integration via the "forward" process, by adding a little noise to x
|
181 |
+
# ...but scaled properly!
|
182 |
+
# But how to convert from original RePaint to k-diffusion? I'll try a few variants
|
183 |
+
repaint_choice = 'orig' # ['orig','var1','var2', etc...]
|
184 |
+
|
185 |
+
sigma_diff = (sigmas[i] - sigmas[i+1]).abs()
|
186 |
+
sigma_ratio = ( sigmas[i+1] / sigma_max ) # use i+1 or i?
|
187 |
+
if repaint_choice == 'orig': # attempt at original RePaint algorithm, which used betas
|
188 |
+
# if sigmas are the std devs, then betas are variances? but beta_max = 1, so how to get that? ratio?
|
189 |
+
beta = sigma_ratio**2
|
190 |
+
x = torch.sqrt(1-beta)*x + torch.sqrt(beta)*torch.randn_like(x) # this is from RePaint Paper
|
191 |
+
elif repaint_choice == 'var1': # or maybe this...? # worse than orig
|
192 |
+
x = x + sigma_diff*torch.randn_like(x)
|
193 |
+
elif repaint_choice == 'var2': # or this...? # yields NaNs
|
194 |
+
x = (1-sigma_diff)*x + sigma_diff*torch.randn_like(x)
|
195 |
+
elif repaint_choice == 'var3': # results similar to var1
|
196 |
+
x = (1.0-sigma_ratio)*x + sigmas[i+1]*torch.randn_like(x)
|
197 |
+
elif repaint_choice == 'var4': # NaNs # stealing code from elsewhere, no idea WTF I'm doing.
|
198 |
+
#Invert this: target = (input - c_skip * noised_input) / c_out, where target = model_output
|
199 |
+
x_tm1, x_t = x, old_x
|
200 |
+
# x_tm1 = ( x_0 - c_skip * noised_x0 ) / c_out
|
201 |
+
# So x_tm1*c_out = x_0 - c_skip * noised_x0
|
202 |
+
input, noise = x_tm1, torch.randn_like(x)
|
203 |
+
noised_input = input + noise * append_dims(sigma_diff, input.ndim)
|
204 |
+
c_skip, c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigmas[i])]
|
205 |
+
model_output = x_tm1
|
206 |
+
renoised_x = c_out * model_output + c_skip * noised_input
|
207 |
+
x = renoised_x
|
208 |
+
elif repaint_choice == 'var5':
|
209 |
+
x = torch.sqrt((1-(sigma_diff/sigma_max)**2))*x + sigma_diff*torch.randn_like(x)
|
210 |
+
|
211 |
+
# include this? guessing no.
|
212 |
+
#old_denoised = denoised
|
213 |
+
#h_last = h
|
214 |
+
|
215 |
+
old_denoised = denoised
|
216 |
+
h_last = h
|
217 |
+
old_x = x
|
218 |
+
return x
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
# -----from stable-audio-tools
|
224 |
+
|
225 |
+
# Define the noise schedule and sampling loop
|
226 |
+
def get_alphas_sigmas(t):
|
227 |
+
"""Returns the scaling factors for the clean image (alpha) and for the
|
228 |
+
noise (sigma), given a timestep."""
|
229 |
+
return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
|
230 |
+
|
231 |
+
def alpha_sigma_to_t(alpha, sigma):
|
232 |
+
"""Returns a timestep, given the scaling factors for the clean image and for
|
233 |
+
the noise."""
|
234 |
+
return torch.atan2(sigma, alpha) / math.pi * 2
|
235 |
+
|
236 |
+
def t_to_alpha_sigma(t):
|
237 |
+
"""Returns the scaling factors for the clean image and for the noise, given
|
238 |
+
a timestep."""
|
239 |
+
return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
|
240 |
+
|
241 |
+
@torch.no_grad()
|
242 |
+
def sample(model, x, steps, eta, **extra_args):
|
243 |
+
"""Draws samples from a model given starting noise. v-diffusion"""
|
244 |
+
ts = x.new_ones([x.shape[0]])
|
245 |
+
|
246 |
+
# Create the noise schedule
|
247 |
+
t = torch.linspace(1, 0, steps + 1)[:-1]
|
248 |
+
|
249 |
+
alphas, sigmas = get_alphas_sigmas(t)
|
250 |
+
|
251 |
+
# The sampling loop
|
252 |
+
for i in trange(steps):
|
253 |
+
|
254 |
+
# Get the model output (v, the predicted velocity)
|
255 |
+
with torch.cuda.amp.autocast():
|
256 |
+
v = model(x, ts * t[i], **extra_args).float()
|
257 |
+
|
258 |
+
# Predict the noise and the denoised image
|
259 |
+
pred = x * alphas[i] - v * sigmas[i]
|
260 |
+
eps = x * sigmas[i] + v * alphas[i]
|
261 |
+
|
262 |
+
# If we are not on the last timestep, compute the noisy image for the
|
263 |
+
# next timestep.
|
264 |
+
if i < steps - 1:
|
265 |
+
# If eta > 0, adjust the scaling factor for the predicted noise
|
266 |
+
# downward according to the amount of additional noise to add
|
267 |
+
ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
|
268 |
+
(1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
|
269 |
+
adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
|
270 |
+
|
271 |
+
# Recombine the predicted noise and predicted denoised image in the
|
272 |
+
# correct proportions for the next step
|
273 |
+
x = pred * alphas[i + 1] + eps * adjusted_sigma
|
274 |
+
|
275 |
+
# Add the correct amount of fresh noise
|
276 |
+
if eta:
|
277 |
+
x += torch.randn_like(x) * ddim_sigma
|
278 |
+
|
279 |
+
# If we are on the last timestep, output the denoised image
|
280 |
+
return pred
|
281 |
+
|
282 |
+
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
283 |
+
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
284 |
+
def get_bmask(i, steps, mask):
|
285 |
+
strength = (i+1)/(steps)
|
286 |
+
# convert to binary mask
|
287 |
+
bmask = torch.where(mask<=strength,1,0)
|
288 |
+
return bmask
|
289 |
+
|
290 |
+
def make_cond_model_fn(model, cond_fn):
|
291 |
+
def cond_model_fn(x, sigma, **kwargs):
|
292 |
+
with torch.enable_grad():
|
293 |
+
x = x.detach().requires_grad_()
|
294 |
+
denoised = model(x, sigma, **kwargs)
|
295 |
+
cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
|
296 |
+
cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
|
297 |
+
return cond_denoised
|
298 |
+
return cond_model_fn
|
299 |
+
|
300 |
+
# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
|
301 |
+
# init_data is init_audio as latents (if this is latent diffusion)
|
302 |
+
# For sampling, set both init_data and mask to None
|
303 |
+
# For variations, set init_data
|
304 |
+
# For inpainting, set both init_data & mask
|
305 |
+
def sample_k(
|
306 |
+
model_fn,
|
307 |
+
noise,
|
308 |
+
init_data=None,
|
309 |
+
mask=None,
|
310 |
+
steps=100,
|
311 |
+
sampler_type="dpmpp-2m-sde",
|
312 |
+
sigma_min=0.5,
|
313 |
+
sigma_max=50,
|
314 |
+
rho=1.0, device="cuda",
|
315 |
+
callback=None,
|
316 |
+
cond_fn=None,
|
317 |
+
model_config=None,
|
318 |
+
repaint=1,
|
319 |
+
**extra_args
|
320 |
+
):
|
321 |
+
|
322 |
+
#denoiser = K.external.VDenoiser(model_fn)
|
323 |
+
denoiser = K.Denoiser(model_fn, sigma_data=model_config['sigma_data'])
|
324 |
+
|
325 |
+
if cond_fn is not None:
|
326 |
+
denoiser = make_cond_model_fn(denoiser, cond_fn)
|
327 |
+
|
328 |
+
# Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
|
329 |
+
#sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
|
330 |
+
sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device=device)
|
331 |
+
print("sigmas[0] = ", sigmas[0])
|
332 |
+
# Scale the initial noise by sigma
|
333 |
+
noise = noise * sigmas[0]
|
334 |
+
|
335 |
+
wrapped_callback = callback
|
336 |
+
|
337 |
+
if mask is None and init_data is not None:
|
338 |
+
# VARIATION (no inpainting)
|
339 |
+
# set the initial latent to the init_data, and noise it with initial sigma
|
340 |
+
x = init_data + noise
|
341 |
+
elif mask is not None and init_data is not None:
|
342 |
+
# INPAINTING
|
343 |
+
bmask = get_bmask(0, steps, mask)
|
344 |
+
# initial noising
|
345 |
+
input_noised = init_data + noise
|
346 |
+
# set the initial latent to a mix of init_data and noise, based on step 0's binary mask
|
347 |
+
x = input_noised * bmask + noise * (1-bmask)
|
348 |
+
# define the inpainting callback function (Note: side effects, it mutates x)
|
349 |
+
# See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
|
350 |
+
# callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
351 |
+
# This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
|
352 |
+
def inpainting_callback(args):
|
353 |
+
i = args["i"]
|
354 |
+
x = args["x"]
|
355 |
+
sigma = args["sigma"]
|
356 |
+
#denoised = args["denoised"]
|
357 |
+
# noise the init_data input with this step's appropriate amount of noise
|
358 |
+
input_noised = init_data + torch.randn_like(init_data) * sigma
|
359 |
+
# shrinking hard mask
|
360 |
+
bmask = get_bmask(i, steps, mask)
|
361 |
+
# mix input_noise with x, using binary mask
|
362 |
+
new_x = input_noised * bmask + x * (1-bmask)
|
363 |
+
# mutate x
|
364 |
+
x[:,:,:] = new_x[:,:,:]
|
365 |
+
# wrap together the inpainting callback and the user-submitted callback.
|
366 |
+
if callback is None:
|
367 |
+
wrapped_callback = inpainting_callback
|
368 |
+
else:
|
369 |
+
wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
|
370 |
+
else:
|
371 |
+
# SAMPLING
|
372 |
+
# set the initial latent to noise
|
373 |
+
x = noise
|
374 |
+
|
375 |
+
|
376 |
+
print("sample_k: x.min, x.max = ", x.min(), x.max())
|
377 |
+
print(f"sample_k: key, val.dtype = ",[ (key, val.dtype if val is not None else val) for key,val in extra_args.items()])
|
378 |
+
with torch.cuda.amp.autocast():
|
379 |
+
if sampler_type == "k-heun":
|
380 |
+
return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
381 |
+
elif sampler_type == "k-lms":
|
382 |
+
return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
383 |
+
elif sampler_type == "k-dpmpp-2s-ancestral":
|
384 |
+
return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
385 |
+
elif sampler_type == "k-dpm-2":
|
386 |
+
return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
387 |
+
elif sampler_type == "k-dpm-fast":
|
388 |
+
return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
389 |
+
elif sampler_type == "k-dpm-adaptive":
|
390 |
+
return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
391 |
+
elif sampler_type == "dpmpp-2m-sde":
|
392 |
+
return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
393 |
+
elif sampler_type == "my-dpmpp-2m-sde":
|
394 |
+
return my_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, repaint=repaint, extra_args=extra_args)
|
395 |
+
elif sampler_type == "dpmpp-3m-sde":
|
396 |
+
return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
397 |
+
elif sampler_type == "my-sample-euler":
|
398 |
+
return my_sample_euler(denoiser, x, sigmas, disable=False, callback=wrapped_callback, repaint=repaint, extra_args=extra_args)
|
399 |
+
|
400 |
+
|
401 |
+
## ---- end stable-audio-tools
|
402 |
+
def infer_mask_from_init_img(img, mask_with='white'):
|
403 |
+
"""given an image with mask areas marked, extract the mask itself
|
404 |
+
note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"""
|
405 |
+
print("Inferring mask from init_img")
|
406 |
+
assert mask_with in ['blue','white']
|
407 |
+
if not torch.is_tensor(img):
|
408 |
+
img = ToTensor()(img)
|
409 |
+
mask = torch.zeros(img.shape[-2:])
|
410 |
+
if mask_with == 'white':
|
411 |
+
mask[ (img[0,:,:]==1) & (img[1,:,:]==1) & (img[2,:,:]==1)] = 1
|
412 |
+
elif mask_with == 'blue':
|
413 |
+
mask[img[2,:,:]==1] = 1 # blue
|
414 |
+
return mask*1.0
|
415 |
+
|
416 |
+
def grow_mask(init_mask, grow_by=2):
|
417 |
+
"adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
|
418 |
+
new_mask = init_mask.clone()
|
419 |
+
for c in range(grow_by):
|
420 |
+
# wherever mask is bordered by a 1, set it to 1
|
421 |
+
new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
|
422 |
+
return new_mask
|
423 |
+
|
424 |
+
|
425 |
+
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
426 |
+
"adds extra noise inside mask"
|
427 |
+
init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
|
428 |
+
if not torch.is_tensor(init_image):
|
429 |
+
init_image = ToTensor()(init_image)
|
430 |
+
init_image = init_image.clone()
|
431 |
+
# wherever mask is 1, set first set init_image to min value
|
432 |
+
init_image[:,init_mask == 1] = init_image.min()
|
433 |
+
init_image = init_image + seed_scale*torch.randn_like(init_image) * (init_mask) # add noise where mask is 1
|
434 |
+
# wherever the mask is 1, set the blue channel to -1.0, otherwise leave it alone
|
435 |
+
init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
|
436 |
+
return init_image
|
437 |
+
|
438 |
+
|
439 |
+
def get_init_image_and_mask(args, device):
|
440 |
+
convert_tensor = transforms.ToTensor()
|
441 |
+
init_image = Image.open(args.init_image).convert('RGB')
|
442 |
+
init_image = convert_tensor(init_image)
|
443 |
+
#normalize image from 0..1 to -1..1
|
444 |
+
init_image = (2.0 * init_image) - 1.0
|
445 |
+
|
446 |
+
|
447 |
+
init_mask = torch.ones(init_image.shape[-2:]) # ones are where stuff will change, zeros will stay the same
|
448 |
+
|
449 |
+
inpaint_task = 'infer' # infer mask from init_image
|
450 |
+
assert inpaint_task in ['accomp','chords','melody','nucleation','notes','continue','infer']
|
451 |
+
|
452 |
+
if inpaint_task in ['melody','accomp']:
|
453 |
+
init_mask[0:70,:] = 0 # zero out a melody strip of image near top
|
454 |
+
init_mask[128+0:128+70,:] = 0 # zero out a melody strip of image along bottom row
|
455 |
+
if inpaint_task == 'melody':
|
456 |
+
init_mask = 1 - init_mask
|
457 |
+
elif inpaint_task in ['notes','chords']:
|
458 |
+
# keep chords only
|
459 |
+
#init_mask = torch.ones_like(x)
|
460 |
+
init_mask[0:CHORD_BORDER,:] = 0 # top row of 256x256
|
461 |
+
init_mask[128-CHORD_BORDER:128+CHORD_BORDER,:] = 0 # middle rows of 256x256
|
462 |
+
init_mask[-CHORD_BORDER:,:] = 0 # bottom row of 256x256
|
463 |
+
if inpaint_task == 'chords':
|
464 |
+
init_mask = 1 - init_mask # inverse: genereate chords given notes
|
465 |
+
elif inpaint_task == 'continue':
|
466 |
+
init_mask[0:128,:] = 0 # remember it's a square, so just mask out the bottom half
|
467 |
+
elif inpaint_task == 'nucleation':
|
468 |
+
# set mask to wherever the blue channel is >= 0.9
|
469 |
+
init_mask = (init_image[2,:,:] > 0.0)*1.0
|
470 |
+
# zero out init mask in top and bottom borders
|
471 |
+
init_mask[0:CHORD_BORDER,:] = 0
|
472 |
+
init_mask[-CHORD_BORDER:,:] = 0
|
473 |
+
init_mask[128-CHORD_BORDER:128+CHORD_BORDER,:] = 0
|
474 |
+
|
475 |
+
# remove all blue in init_image between the borders
|
476 |
+
init_image[2,CHORD_BORDER:128-CHORD_BORDER,:] = -1.0
|
477 |
+
init_image[2,128+CHORD_BORDER:-CHORD_BORDER,:] = -1.0
|
478 |
+
|
479 |
+
# grow the sides of the mask by one pixel:
|
480 |
+
# wherever mask is zero but is bordered by a 1, set it to 1
|
481 |
+
init_mask[1:-1,1:-1] = (init_mask[1:-1,1:-1] + init_mask[0:-2,1:-1] + init_mask[2:,1:-1] + init_mask[1:-1,0:-2] + init_mask[1:-1,2:]) > 0
|
482 |
+
#init_mask[1:-1,1:-1] = (init_mask[1:-1,1:-1] + init_mask[0:-2,1:-1] + init_mask[2:,1:-1] + init_mask[1:-1,0:-2] + init_mask[1:-1,2:]) > 0
|
483 |
+
elif inpaint_task == 'infer':
|
484 |
+
init_mask = infer_mask_from_init_img(init_image, mask_with='white')
|
485 |
+
|
486 |
+
# Also black out init_image wherever init mask is 1
|
487 |
+
init_image[:,init_mask == 1] = init_image.min()
|
488 |
+
|
489 |
+
if args.seed_scale > 0: # driving nucleation
|
490 |
+
print("Seeding nucleation, seed_scale = ", args.seed_scale)
|
491 |
+
init_image = add_seeding(init_image, init_mask, grow_by=0, seed_scale=args.seed_scale)
|
492 |
+
|
493 |
+
# remove any blue in middle of init image
|
494 |
+
print("init_image.shape = ", init_image.shape)
|
495 |
+
init_image[2,CHORD_BORDER:128-CHORD_BORDER,:] = -1.0
|
496 |
+
init_image[2,128+CHORD_BORDER:-CHORD_BORDER,:] = -1.0
|
497 |
+
|
498 |
+
# Debugging: output some images so we can see what's going on
|
499 |
+
init_mask_t = init_mask.float()*255 # convert mask to 0..255 for writing as image
|
500 |
+
# Convert to NumPy array and rearrange dimensions
|
501 |
+
init_mask_img_numpy = init_mask_t.byte().cpu().numpy()#.transpose(1, 2, 0)
|
502 |
+
init_mask_debug_img = Image.fromarray(init_mask_img_numpy)
|
503 |
+
init_mask_debug_img.save("init_mask_debug.png")
|
504 |
+
init_image_debug_img = Image.fromarray((init_image*127.5+127.5).byte().cpu().numpy().transpose(1,2,0))
|
505 |
+
init_image_debug_img.save("init_image_debug.png")
|
506 |
+
|
507 |
+
# reshape image and mask to be 4D tensors
|
508 |
+
init_image = init_image.unsqueeze(0).repeat(args.batch_size, 1, 1, 1)
|
509 |
+
init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
|
510 |
+
return init_image.to(device), init_mask.to(device)
|
511 |
+
|
512 |
+
|
513 |
+
def main():
|
514 |
+
global init_image, init_mask
|
515 |
+
p = argparse.ArgumentParser(description=__doc__,
|
516 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
517 |
+
p.add_argument('--batch-size', type=int, default=64,
|
518 |
+
help='the batch size')
|
519 |
+
p.add_argument('--checkpoint', type=Path, required=True,
|
520 |
+
help='the checkpoint to use')
|
521 |
+
p.add_argument('--config', type=Path,
|
522 |
+
help='the model config')
|
523 |
+
p.add_argument('-n', type=int, default=64,
|
524 |
+
help='the number of images to sample')
|
525 |
+
p.add_argument('--prefix', type=str, default='out',
|
526 |
+
help='the output prefix')
|
527 |
+
p.add_argument('--repaint', type=int, default=1,
|
528 |
+
help='number of (re)paint steps')
|
529 |
+
p.add_argument('--steps', type=int, default=50,
|
530 |
+
help='the number of denoising steps')
|
531 |
+
p.add_argument('--seed-scale', type=float, default=0.0, help='strength of nucleation seeding')
|
532 |
+
p.add_argument('--init-image', type=Path, default=None, help='the initial image')
|
533 |
+
p.add_argument('--init-strength', type=float, default=1., help='strength of init image')
|
534 |
+
args = p.parse_args()
|
535 |
+
print("args =", args, flush=True)
|
536 |
+
|
537 |
+
config = K.config.load_config(args.config if args.config else args.checkpoint)
|
538 |
+
model_config = config['model']
|
539 |
+
# TODO: allow non-square input sizes
|
540 |
+
assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
|
541 |
+
size = model_config['input_size']
|
542 |
+
|
543 |
+
accelerator = accelerate.Accelerator()
|
544 |
+
device = accelerator.device
|
545 |
+
print('Using device:', device, flush=True)
|
546 |
+
|
547 |
+
inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
|
548 |
+
cse = None # ChordSeqEncoder().eval().requires_grad_(False).to(device) # add chord embedding-maker to main model
|
549 |
+
if cse is not None:
|
550 |
+
inner_model.cse = cse
|
551 |
+
try:
|
552 |
+
inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
|
553 |
+
except:
|
554 |
+
#ckpt = torch.load(args.checkpoint).to(device)
|
555 |
+
ckpt = torch.load(args.checkpoint, map_location='cpu')
|
556 |
+
inner_model.load_state_dict(ckpt['model'])
|
557 |
+
|
558 |
+
accelerator.print('Parameters:', K.utils.n_params(inner_model))
|
559 |
+
model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])
|
560 |
+
|
561 |
+
sigma_min = model_config['sigma_min']
|
562 |
+
sigma_max = model_config['sigma_max']
|
563 |
+
|
564 |
+
# SHH modified
|
565 |
+
torch.set_float32_matmul_precision('high')
|
566 |
+
#class_cond = torch.tensor([0]).to(device)
|
567 |
+
#num_classes = 10
|
568 |
+
#class_cond = torch.remainder(torch.arange(0, args.n), num_classes).int().to(device)
|
569 |
+
#extra_args = {'class_cond':class_cond}
|
570 |
+
extra_args = {}
|
571 |
+
init_image, init_mask = None, None
|
572 |
+
if args.init_image is not None:
|
573 |
+
init_image, init_mask = get_init_image_and_mask(args, device)
|
574 |
+
init_image = init_image.to(device)
|
575 |
+
init_mask = init_mask.to(device)
|
576 |
+
|
577 |
+
@torch.no_grad()
|
578 |
+
@K.utils.eval_mode(model)
|
579 |
+
def run():
|
580 |
+
global init_image, init_mask
|
581 |
+
if accelerator.is_local_main_process:
|
582 |
+
tqdm.write('Sampling...')
|
583 |
+
sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
|
584 |
+
|
585 |
+
#ddpm_sampler = DDPM(model)
|
586 |
+
#model_fn = model
|
587 |
+
#ddpm_sampler = K.external.VDenoiser(model_fn)
|
588 |
+
|
589 |
+
def sample_fn(n, debug=True):
|
590 |
+
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
|
591 |
+
print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
|
592 |
+
|
593 |
+
if args.init_image is not None:
|
594 |
+
init_data, mask = get_init_image_and_mask(args, device)
|
595 |
+
init_data = args.seed_scale*x*mask + (1-mask)*init_data # extra nucleation?
|
596 |
+
if cse is not None:
|
597 |
+
chord_cond = img_batch_to_seq_emb(init_data, inner_model.cse).to(device)
|
598 |
+
else:
|
599 |
+
chord_cond = None
|
600 |
+
#print("init_data.shape, init_data.min, init_data.max = ", init_data.shape, init_data.min(), init_data.max())
|
601 |
+
else:
|
602 |
+
init_data, mask, chord_cond = None, None, None
|
603 |
+
|
604 |
+
print("chord_cond = ", chord_cond)
|
605 |
+
extra_args['chord_cond'] = chord_cond
|
606 |
+
# these two work:
|
607 |
+
#x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
|
608 |
+
#x_0 = K.sampling.sample_dpmpp_2m_sde(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
|
609 |
+
|
610 |
+
noise = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device)
|
611 |
+
|
612 |
+
sampler_type="my-dpmpp-2m-sde" # "k-lms"
|
613 |
+
#sampler_type="my-sample-euler"
|
614 |
+
#sampler_type="dpmpp-2m-sde"
|
615 |
+
#sampler_type = "dpmpp-3m-sde"
|
616 |
+
#sampler_type = "k-dpmpp-2s-ancestral"
|
617 |
+
print("dtypes:", [x.dtype if x is not None else None for x in [noise, init_data, mask, chord_cond]])
|
618 |
+
x_0 = sample_k(inner_model, noise, sampler_type=sampler_type,
|
619 |
+
init_data=init_data, mask=mask, steps=args.steps,
|
620 |
+
sigma_min=sigma_min, sigma_max=sigma_max, rho=7.,
|
621 |
+
device=device, model_config=model_config, repaint=args.repaint,
|
622 |
+
**extra_args)
|
623 |
+
#x_0 = sample_k(inner_model, noise, sampler_type="dpmpp-2m-sde", steps=100, sigma_min=0.5, sigma_max=50, rho=1., device=device, model_config=model_config, **extra_args)
|
624 |
+
print("x_0.min, x_0.max = ", x_0.min(), x_0.max())
|
625 |
+
if x_0.isnan().any():
|
626 |
+
assert False, "x_0 has NaNs"
|
627 |
+
|
628 |
+
# do gpu garbage collection before proceeding
|
629 |
+
torch.cuda.empty_cache()
|
630 |
+
return x_0
|
631 |
+
|
632 |
+
x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
|
633 |
+
if accelerator.is_main_process:
|
634 |
+
for i, out in enumerate(x_0):
|
635 |
+
filename = f'{args.prefix}_{i:05}.png'
|
636 |
+
K.utils.to_pil_image(out).save(filename)
|
637 |
+
|
638 |
+
try:
|
639 |
+
run()
|
640 |
+
except KeyboardInterrupt:
|
641 |
+
pass
|
642 |
+
|
643 |
+
|
644 |
+
if __name__ == '__main__':
|
645 |
+
main()
|