cocktailpeanut commited on
Commit
8611f7d
·
1 Parent(s): 41a3c74
stable_diffusion/ldm/models/diffusion/ddim.py CHANGED
@@ -4,10 +4,12 @@ import torch
4
  import numpy as np
5
  from tqdm import tqdm
6
  from functools import partial
 
7
 
8
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
  extract_into_tensor
10
 
 
11
 
12
  class DDIMSampler(object):
13
  def __init__(self, model, schedule="linear", **kwargs):
@@ -18,8 +20,8 @@ class DDIMSampler(object):
18
 
19
  def register_buffer(self, name, attr):
20
  if type(attr) == torch.Tensor:
21
- if attr.device != torch.device("cuda"):
22
- attr = attr.to(torch.device("cuda"))
23
  setattr(self, name, attr)
24
 
25
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
@@ -238,4 +240,4 @@ class DDIMSampler(object):
238
  x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
239
  unconditional_guidance_scale=unconditional_guidance_scale,
240
  unconditional_conditioning=unconditional_conditioning)
241
- return x_dec
 
4
  import numpy as np
5
  from tqdm import tqdm
6
  from functools import partial
7
+ import devicetorch
8
 
9
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
10
  extract_into_tensor
11
 
12
+ device = devicetorch.get(torch)
13
 
14
  class DDIMSampler(object):
15
  def __init__(self, model, schedule="linear", **kwargs):
 
20
 
21
  def register_buffer(self, name, attr):
22
  if type(attr) == torch.Tensor:
23
+ if attr.device != torch.device(device):
24
+ attr = attr.to(torch.device(device))
25
  setattr(self, name, attr)
26
 
27
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
 
240
  x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
241
  unconditional_guidance_scale=unconditional_guidance_scale,
242
  unconditional_conditioning=unconditional_conditioning)
243
+ return x_dec
stable_diffusion/ldm/models/diffusion/plms.py CHANGED
@@ -4,9 +4,11 @@ import torch
4
  import numpy as np
5
  from tqdm import tqdm
6
  from functools import partial
 
7
 
8
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
 
 
10
 
11
  class PLMSSampler(object):
12
  def __init__(self, model, schedule="linear", **kwargs):
@@ -17,8 +19,8 @@ class PLMSSampler(object):
17
 
18
  def register_buffer(self, name, attr):
19
  if type(attr) == torch.Tensor:
20
- if attr.device != torch.device("cuda"):
21
- attr = attr.to(torch.device("cuda"))
22
  setattr(self, name, attr)
23
 
24
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
 
4
  import numpy as np
5
  from tqdm import tqdm
6
  from functools import partial
7
+ import devicetorch
8
 
9
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
10
 
11
+ device = devicetorch.get(torch)
12
 
13
  class PLMSSampler(object):
14
  def __init__(self, model, schedule="linear", **kwargs):
 
19
 
20
  def register_buffer(self, name, attr):
21
  if type(attr) == torch.Tensor:
22
+ if attr.device != torch.device(device):
23
+ attr = attr.to(torch.device(device))
24
  setattr(self, name, attr)
25
 
26
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
stable_diffusion/notebook_helpers.py CHANGED
@@ -14,7 +14,9 @@ from ldm.models.diffusion.ddim import DDIMSampler
14
  from ldm.util import ismap
15
  import time
16
  from omegaconf import OmegaConf
 
17
 
 
18
 
19
  def download_models(mode):
20
 
@@ -44,7 +46,7 @@ def load_model_from_config(config, ckpt):
44
  sd = pl_sd["state_dict"]
45
  model = instantiate_from_config(config.model)
46
  m, u = model.load_state_dict(sd, strict=False)
47
- model.cuda()
48
  model.eval()
49
  return {"model": model}, global_step
50
 
@@ -117,7 +119,7 @@ def get_cond(mode, selected_path):
117
  c = rearrange(c, '1 c h w -> 1 h w c')
118
  c = 2. * c - 1.
119
 
120
- c = c.to(torch.device("cuda"))
121
  example["LR_image"] = c
122
  example["image"] = c_up
123
 
@@ -267,4 +269,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
267
  log["sample"] = x_sample
268
  log["time"] = t1 - t0
269
 
270
- return log
 
14
  from ldm.util import ismap
15
  import time
16
  from omegaconf import OmegaConf
17
+ import devicetorch
18
 
19
+ device = devicetorch.get(torch)
20
 
21
  def download_models(mode):
22
 
 
46
  sd = pl_sd["state_dict"]
47
  model = instantiate_from_config(config.model)
48
  m, u = model.load_state_dict(sd, strict=False)
49
+ model.to(device)
50
  model.eval()
51
  return {"model": model}, global_step
52
 
 
119
  c = rearrange(c, '1 c h w -> 1 h w c')
120
  c = 2. * c - 1.
121
 
122
+ c = c.to(torch.device(device))
123
  example["LR_image"] = c
124
  example["image"] = c_up
125
 
 
269
  log["sample"] = x_sample
270
  log["time"] = t1 - t0
271
 
272
+ return log