Spaces:
Sleeping
Sleeping
drscotthawley
committed on
Commit
•
6873531
1
Parent(s):
500319a
adding zerogpu decorators
Browse files
sample.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
|
3 |
# Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH)
|
|
|
4 |
|
5 |
"""Samples from k-diffusion models."""
|
6 |
|
|
|
|
|
|
|
7 |
import argparse
|
8 |
from pathlib import Path
|
9 |
|
@@ -24,6 +28,7 @@ from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
|
|
24 |
# ---- my mangled sampler that includes repaint
|
25 |
import torchsde
|
26 |
|
|
|
27 |
class BatchedBrownianTree:
|
28 |
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
29 |
|
@@ -51,6 +56,7 @@ class BatchedBrownianTree:
|
|
51 |
return w if self.batched else w[0]
|
52 |
|
53 |
|
|
|
54 |
class BrownianTreeNoiseSampler:
|
55 |
"""A noise sampler backed by a torchsde.BrownianTree.
|
56 |
|
@@ -88,6 +94,7 @@ def to_d(x, sigma, denoised):
|
|
88 |
return (x - denoised) / append_dims(sigma, x.ndim)
|
89 |
|
90 |
|
|
|
91 |
@torch.no_grad()
|
92 |
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
|
93 |
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
@@ -122,6 +129,7 @@ def get_scalings(sigma, sigma_data=0.5):
|
|
122 |
return c_skip, c_out, c_in
|
123 |
|
124 |
|
|
|
125 |
@torch.no_grad()
|
126 |
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
|
127 |
disable=None, eta=1., s_noise=1., noise_sampler=None,
|
@@ -281,12 +289,14 @@ def sample(model, x, steps, eta, **extra_args):
|
|
281 |
|
282 |
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
283 |
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
|
|
284 |
def get_bmask(i, steps, mask):
|
285 |
strength = (i+1)/(steps)
|
286 |
# convert to binary mask
|
287 |
bmask = torch.where(mask<=strength,1,0)
|
288 |
return bmask
|
289 |
|
|
|
290 |
def make_cond_model_fn(model, cond_fn):
|
291 |
def cond_model_fn(x, sigma, **kwargs):
|
292 |
with torch.enable_grad():
|
@@ -302,6 +312,7 @@ def make_cond_model_fn(model, cond_fn):
|
|
302 |
# For sampling, set both init_data and mask to None
|
303 |
# For variations, set init_data
|
304 |
# For inpainting, set both init_data & mask
|
|
|
305 |
def sample_k(
|
306 |
model_fn,
|
307 |
noise,
|
@@ -399,6 +410,7 @@ def sample_k(
|
|
399 |
|
400 |
|
401 |
## ---- end stable-audio-tools
|
|
|
402 |
def infer_mask_from_init_img(img, mask_with='white'):
|
403 |
"""given an image with mask areas marked, extract the mask itself
|
404 |
note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"""
|
@@ -413,6 +425,7 @@ def infer_mask_from_init_img(img, mask_with='white'):
|
|
413 |
mask[img[2,:,:]==1] = 1 # blue
|
414 |
return mask*1.0
|
415 |
|
|
|
416 |
def grow_mask(init_mask, grow_by=2):
|
417 |
"adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
|
418 |
new_mask = init_mask.clone()
|
@@ -421,7 +434,7 @@ def grow_mask(init_mask, grow_by=2):
|
|
421 |
new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
|
422 |
return new_mask
|
423 |
|
424 |
-
|
425 |
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
426 |
"adds extra noise inside mask"
|
427 |
init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
|
@@ -435,7 +448,7 @@ def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
|
435 |
init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
|
436 |
return init_image
|
437 |
|
438 |
-
|
439 |
def get_init_image_and_mask(args, device):
|
440 |
convert_tensor = transforms.ToTensor()
|
441 |
init_image = Image.open(args.init_image).convert('RGB')
|
@@ -509,7 +522,7 @@ def get_init_image_and_mask(args, device):
|
|
509 |
init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
|
510 |
return init_image.to(device), init_mask.to(device)
|
511 |
|
512 |
-
|
513 |
def main():
|
514 |
global init_image, init_mask
|
515 |
p = argparse.ArgumentParser(description=__doc__,
|
@@ -586,6 +599,7 @@ def main():
|
|
586 |
#model_fn = model
|
587 |
#ddpm_sampler = K.external.VDenoiser(model_fn)
|
588 |
|
|
|
589 |
def sample_fn(n, debug=True):
|
590 |
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
|
591 |
print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
|
3 |
# Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH)
|
4 |
+
# Modified by Scott H. Hawley for masking, ZeroGPU, etc.
|
5 |
|
6 |
"""Samples from k-diffusion models."""
|
7 |
|
8 |
+
import gradio
|
9 |
+
import spaces
|
10 |
+
import natten
|
11 |
import argparse
|
12 |
from pathlib import Path
|
13 |
|
|
|
28 |
# ---- my mangled sampler that includes repaint
|
29 |
import torchsde
|
30 |
|
31 |
+
@spaces.GPU
|
32 |
class BatchedBrownianTree:
|
33 |
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
34 |
|
|
|
56 |
return w if self.batched else w[0]
|
57 |
|
58 |
|
59 |
+
@spaces.GPU
|
60 |
class BrownianTreeNoiseSampler:
|
61 |
"""A noise sampler backed by a torchsde.BrownianTree.
|
62 |
|
|
|
94 |
return (x - denoised) / append_dims(sigma, x.ndim)
|
95 |
|
96 |
|
97 |
+
@spaces.GPU
|
98 |
@torch.no_grad()
|
99 |
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
|
100 |
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
|
|
129 |
return c_skip, c_out, c_in
|
130 |
|
131 |
|
132 |
+
@spaces.GPU
|
133 |
@torch.no_grad()
|
134 |
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
|
135 |
disable=None, eta=1., s_noise=1., noise_sampler=None,
|
|
|
289 |
|
290 |
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
291 |
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
292 |
+
@spaces.GPU
|
293 |
def get_bmask(i, steps, mask):
|
294 |
strength = (i+1)/(steps)
|
295 |
# convert to binary mask
|
296 |
bmask = torch.where(mask<=strength,1,0)
|
297 |
return bmask
|
298 |
|
299 |
+
@spaces.GPU
|
300 |
def make_cond_model_fn(model, cond_fn):
|
301 |
def cond_model_fn(x, sigma, **kwargs):
|
302 |
with torch.enable_grad():
|
|
|
312 |
# For sampling, set both init_data and mask to None
|
313 |
# For variations, set init_data
|
314 |
# For inpainting, set both init_data & mask
|
315 |
+
@spaces.GPU
|
316 |
def sample_k(
|
317 |
model_fn,
|
318 |
noise,
|
|
|
410 |
|
411 |
|
412 |
## ---- end stable-audio-tools
|
413 |
+
@spaces.GPU
|
414 |
def infer_mask_from_init_img(img, mask_with='white'):
|
415 |
"""given an image with mask areas marked, extract the mask itself
|
416 |
note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"""
|
|
|
425 |
mask[img[2,:,:]==1] = 1 # blue
|
426 |
return mask*1.0
|
427 |
|
428 |
+
@spaces.GPU
|
429 |
def grow_mask(init_mask, grow_by=2):
|
430 |
"adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
|
431 |
new_mask = init_mask.clone()
|
|
|
434 |
new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
|
435 |
return new_mask
|
436 |
|
437 |
+
@spaces.GPU
|
438 |
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
439 |
"adds extra noise inside mask"
|
440 |
init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
|
|
|
448 |
init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
|
449 |
return init_image
|
450 |
|
451 |
+
@spaces.GPU
|
452 |
def get_init_image_and_mask(args, device):
|
453 |
convert_tensor = transforms.ToTensor()
|
454 |
init_image = Image.open(args.init_image).convert('RGB')
|
|
|
522 |
init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
|
523 |
return init_image.to(device), init_mask.to(device)
|
524 |
|
525 |
+
@spaces.GPU
|
526 |
def main():
|
527 |
global init_image, init_mask
|
528 |
p = argparse.ArgumentParser(description=__doc__,
|
|
|
599 |
#model_fn = model
|
600 |
#ddpm_sampler = K.external.VDenoiser(model_fn)
|
601 |
|
602 |
+
@spaces.GPU
|
603 |
def sample_fn(n, debug=True):
|
604 |
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
|
605 |
print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
|