Spaces:
Running
on
Zero
Running
on
Zero
drscotthawley
commited on
Commit
•
9411f2b
1
Parent(s):
6266660
removed zerogpu decorators
Browse files
sample.py
CHANGED
@@ -28,7 +28,7 @@ from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
|
|
28 |
# ---- my mangled sampler that includes repaint
|
29 |
import torchsde
|
30 |
|
31 |
-
|
32 |
class BatchedBrownianTree:
|
33 |
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
34 |
|
@@ -56,7 +56,7 @@ class BatchedBrownianTree:
|
|
56 |
return w if self.batched else w[0]
|
57 |
|
58 |
|
59 |
-
|
60 |
class BrownianTreeNoiseSampler:
|
61 |
"""A noise sampler backed by a torchsde.BrownianTree.
|
62 |
|
@@ -94,7 +94,7 @@ def to_d(x, sigma, denoised):
|
|
94 |
return (x - denoised) / append_dims(sigma, x.ndim)
|
95 |
|
96 |
|
97 |
-
|
98 |
@torch.no_grad()
|
99 |
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
|
100 |
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
@@ -129,7 +129,7 @@ def get_scalings(sigma, sigma_data=0.5):
|
|
129 |
return c_skip, c_out, c_in
|
130 |
|
131 |
|
132 |
-
|
133 |
@torch.no_grad()
|
134 |
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
|
135 |
disable=None, eta=1., s_noise=1., noise_sampler=None,
|
@@ -289,14 +289,14 @@ def sample(model, x, steps, eta, **extra_args):
|
|
289 |
|
290 |
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
291 |
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
292 |
-
|
293 |
def get_bmask(i, steps, mask):
|
294 |
strength = (i+1)/(steps)
|
295 |
# convert to binary mask
|
296 |
bmask = torch.where(mask<=strength,1,0)
|
297 |
return bmask
|
298 |
|
299 |
-
|
300 |
def make_cond_model_fn(model, cond_fn):
|
301 |
def cond_model_fn(x, sigma, **kwargs):
|
302 |
with torch.enable_grad():
|
@@ -312,7 +312,7 @@ def make_cond_model_fn(model, cond_fn):
|
|
312 |
# For sampling, set both init_data and mask to None
|
313 |
# For variations, set init_data
|
314 |
# For inpainting, set both init_data & mask
|
315 |
-
|
316 |
def sample_k(
|
317 |
model_fn,
|
318 |
noise,
|
@@ -410,7 +410,7 @@ def sample_k(
|
|
410 |
|
411 |
|
412 |
## ---- end stable-audio-tools
|
413 |
-
|
414 |
def infer_mask_from_init_img(img, mask_with='white'):
|
415 |
"""given an image with mask areas marked, extract the mask itself
|
416 |
note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"""
|
@@ -425,7 +425,7 @@ def infer_mask_from_init_img(img, mask_with='white'):
|
|
425 |
mask[img[2,:,:]==1] = 1 # blue
|
426 |
return mask*1.0
|
427 |
|
428 |
-
|
429 |
def grow_mask(init_mask, grow_by=2):
|
430 |
"adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
|
431 |
new_mask = init_mask.clone()
|
@@ -434,7 +434,7 @@ def grow_mask(init_mask, grow_by=2):
|
|
434 |
new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
|
435 |
return new_mask
|
436 |
|
437 |
-
|
438 |
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
439 |
"adds extra noise inside mask"
|
440 |
init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
|
@@ -448,7 +448,7 @@ def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
|
448 |
init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
|
449 |
return init_image
|
450 |
|
451 |
-
|
452 |
def get_init_image_and_mask(args, device):
|
453 |
convert_tensor = transforms.ToTensor()
|
454 |
init_image = Image.open(args.init_image).convert('RGB')
|
@@ -599,7 +599,7 @@ def main():
|
|
599 |
#model_fn = model
|
600 |
#ddpm_sampler = K.external.VDenoiser(model_fn)
|
601 |
|
602 |
-
|
603 |
def sample_fn(n, debug=True):
|
604 |
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
|
605 |
print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
|
|
|
28 |
# ---- my mangled sampler that includes repaint
|
29 |
import torchsde
|
30 |
|
31 |
+
#@spaces.GPU
|
32 |
class BatchedBrownianTree:
|
33 |
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
34 |
|
|
|
56 |
return w if self.batched else w[0]
|
57 |
|
58 |
|
59 |
+
#@spaces.GPU
|
60 |
class BrownianTreeNoiseSampler:
|
61 |
"""A noise sampler backed by a torchsde.BrownianTree.
|
62 |
|
|
|
94 |
return (x - denoised) / append_dims(sigma, x.ndim)
|
95 |
|
96 |
|
97 |
+
#@spaces.GPU
|
98 |
@torch.no_grad()
|
99 |
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
|
100 |
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
|
|
129 |
return c_skip, c_out, c_in
|
130 |
|
131 |
|
132 |
+
#@spaces.GPU
|
133 |
@torch.no_grad()
|
134 |
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
|
135 |
disable=None, eta=1., s_noise=1., noise_sampler=None,
|
|
|
289 |
|
290 |
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
291 |
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
292 |
+
#@spaces.GPU
|
293 |
def get_bmask(i, steps, mask):
|
294 |
strength = (i+1)/(steps)
|
295 |
# convert to binary mask
|
296 |
bmask = torch.where(mask<=strength,1,0)
|
297 |
return bmask
|
298 |
|
299 |
+
#@spaces.GPU
|
300 |
def make_cond_model_fn(model, cond_fn):
|
301 |
def cond_model_fn(x, sigma, **kwargs):
|
302 |
with torch.enable_grad():
|
|
|
312 |
# For sampling, set both init_data and mask to None
|
313 |
# For variations, set init_data
|
314 |
# For inpainting, set both init_data & mask
|
315 |
+
#@spaces.GPU
|
316 |
def sample_k(
|
317 |
model_fn,
|
318 |
noise,
|
|
|
410 |
|
411 |
|
412 |
## ---- end stable-audio-tools
|
413 |
+
#@spaces.GPU
|
414 |
def infer_mask_from_init_img(img, mask_with='white'):
|
415 |
"""given an image with mask areas marked, extract the mask itself
|
416 |
note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"""
|
|
|
425 |
mask[img[2,:,:]==1] = 1 # blue
|
426 |
return mask*1.0
|
427 |
|
428 |
+
#@spaces.GPU
|
429 |
def grow_mask(init_mask, grow_by=2):
|
430 |
"adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
|
431 |
new_mask = init_mask.clone()
|
|
|
434 |
new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
|
435 |
return new_mask
|
436 |
|
437 |
+
#@spaces.GPU
|
438 |
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
439 |
"adds extra noise inside mask"
|
440 |
init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
|
|
|
448 |
init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
|
449 |
return init_image
|
450 |
|
451 |
+
#@spaces.GPU
|
452 |
def get_init_image_and_mask(args, device):
|
453 |
convert_tensor = transforms.ToTensor()
|
454 |
init_image = Image.open(args.init_image).convert('RGB')
|
|
|
599 |
#model_fn = model
|
600 |
#ddpm_sampler = K.external.VDenoiser(model_fn)
|
601 |
|
602 |
+
#@spaces.GPU
|
603 |
def sample_fn(n, debug=True):
|
604 |
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
|
605 |
print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
|