cocktailpeanut commited on
Commit
517c053
β€’
1 Parent(s): bd5e995
customnet/ddim.py CHANGED
@@ -9,6 +9,12 @@ from einops import rearrange
9
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
10
  from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
11
 
 
 
 
 
 
 
12
 
13
  class DDIMSampler(object):
14
  def __init__(self, model, schedule="linear", **kwargs):
@@ -27,9 +33,10 @@ class DDIMSampler(object):
27
 
28
 
29
  def register_buffer(self, name, attr):
30
- if type(attr) == torch.Tensor:
31
- if attr.device != torch.device("cuda"):
32
- attr = attr.to(torch.device("cuda"))
 
33
  setattr(self, name, attr)
34
 
35
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
 
9
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
10
  from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
11
 
12
+ if torch.cuda.is_available():
13
+ _device = "cuda"
14
+ elif torch.backends.mps.is_available():
15
+ _device = "mps"
16
+ else:
17
+ _device = "cpu"
18
 
19
  class DDIMSampler(object):
20
  def __init__(self, model, schedule="linear", **kwargs):
 
33
 
34
 
35
  def register_buffer(self, name, attr):
36
+ if _device == "cuda":
37
+ if type(attr) == torch.Tensor:
38
+ if attr.device != torch.device("cuda"):
39
+ attr = attr.to(torch.device("cuda"))
40
  setattr(self, name, attr)
41
 
42
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
ldm/models/diffusion/ddim.py CHANGED
@@ -9,6 +9,12 @@ from einops import rearrange
9
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
10
  from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
11
 
 
 
 
 
 
 
12
 
13
  class DDIMSampler(object):
14
  def __init__(self, model, schedule="linear", **kwargs):
@@ -27,9 +33,10 @@ class DDIMSampler(object):
27
 
28
 
29
  def register_buffer(self, name, attr):
30
- if type(attr) == torch.Tensor:
31
- if attr.device != torch.device("cuda"):
32
- attr = attr.to(torch.device("cuda"))
 
33
  setattr(self, name, attr)
34
 
35
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
@@ -322,4 +329,4 @@ class DDIMSampler(object):
322
  x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
323
  unconditional_guidance_scale=unconditional_guidance_scale,
324
  unconditional_conditioning=unconditional_conditioning)
325
- return x_dec
 
9
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
10
  from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
11
 
12
+ if torch.cuda.is_available():
13
+ _device = "cuda"
14
+ elif torch.backends.mps.is_available():
15
+ _device = "mps"
16
+ else:
17
+ _device = "cpu"
18
 
19
  class DDIMSampler(object):
20
  def __init__(self, model, schedule="linear", **kwargs):
 
33
 
34
 
35
  def register_buffer(self, name, attr):
36
+ if _device == "cuda":
37
+ if type(attr) == torch.Tensor:
38
+ if attr.device != torch.device("cuda"):
39
+ attr = attr.to(torch.device("cuda"))
40
  setattr(self, name, attr)
41
 
42
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
 
329
  x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
330
  unconditional_guidance_scale=unconditional_guidance_scale,
331
  unconditional_conditioning=unconditional_conditioning)
332
+ return x_dec
ldm/models/diffusion/plms.py CHANGED
@@ -8,6 +8,12 @@ from functools import partial
8
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
  from ldm.models.diffusion.sampling_util import norm_thresholding
10
 
 
 
 
 
 
 
11
 
12
  class PLMSSampler(object):
13
  def __init__(self, model, schedule="linear", **kwargs):
@@ -17,9 +23,10 @@ class PLMSSampler(object):
17
  self.schedule = schedule
18
 
19
  def register_buffer(self, name, attr):
20
- if type(attr) == torch.Tensor:
21
- if attr.device != torch.device("cuda"):
22
- attr = attr.to(torch.device("cuda"))
 
23
  setattr(self, name, attr)
24
 
25
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
 
8
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
  from ldm.models.diffusion.sampling_util import norm_thresholding
10
 
11
+ if torch.cuda.is_available():
12
+ _device = "cuda"
13
+ elif torch.backends.mps.is_available():
14
+ _device = "mps"
15
+ else:
16
+ _device = "cpu"
17
 
18
  class PLMSSampler(object):
19
  def __init__(self, model, schedule="linear", **kwargs):
 
23
  self.schedule = schedule
24
 
25
  def register_buffer(self, name, attr):
26
+ if _device == "cuda":
27
+ if type(attr) == torch.Tensor:
28
+ if attr.device != torch.device("cuda"):
29
+ attr = attr.to(torch.device("cuda"))
30
  setattr(self, name, attr)
31
 
32
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):