Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
·
8611f7d
1
Parent(s):
41a3c74
update
Browse files
stable_diffusion/ldm/models/diffusion/ddim.py
CHANGED
@@ -4,10 +4,12 @@ import torch
|
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
6 |
from functools import partial
|
|
|
7 |
|
8 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
9 |
extract_into_tensor
|
10 |
|
|
|
11 |
|
12 |
class DDIMSampler(object):
|
13 |
def __init__(self, model, schedule="linear", **kwargs):
|
@@ -18,8 +20,8 @@ class DDIMSampler(object):
|
|
18 |
|
19 |
def register_buffer(self, name, attr):
|
20 |
if type(attr) == torch.Tensor:
|
21 |
-
if attr.device != torch.device(
|
22 |
-
attr = attr.to(torch.device(
|
23 |
setattr(self, name, attr)
|
24 |
|
25 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
@@ -238,4 +240,4 @@ class DDIMSampler(object):
|
|
238 |
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
239 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
240 |
unconditional_conditioning=unconditional_conditioning)
|
241 |
-
return x_dec
|
|
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
6 |
from functools import partial
|
7 |
+
import devicetorch
|
8 |
|
9 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
10 |
extract_into_tensor
|
11 |
|
12 |
+
device = devicetorch.get(torch)
|
13 |
|
14 |
class DDIMSampler(object):
|
15 |
def __init__(self, model, schedule="linear", **kwargs):
|
|
|
20 |
|
21 |
def register_buffer(self, name, attr):
|
22 |
if type(attr) == torch.Tensor:
|
23 |
+
if attr.device != torch.device(device):
|
24 |
+
attr = attr.to(torch.device(device))
|
25 |
setattr(self, name, attr)
|
26 |
|
27 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
|
|
240 |
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
241 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
242 |
unconditional_conditioning=unconditional_conditioning)
|
243 |
+
return x_dec
|
stable_diffusion/ldm/models/diffusion/plms.py
CHANGED
@@ -4,9 +4,11 @@ import torch
|
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
6 |
from functools import partial
|
|
|
7 |
|
8 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
9 |
|
|
|
10 |
|
11 |
class PLMSSampler(object):
|
12 |
def __init__(self, model, schedule="linear", **kwargs):
|
@@ -17,8 +19,8 @@ class PLMSSampler(object):
|
|
17 |
|
18 |
def register_buffer(self, name, attr):
|
19 |
if type(attr) == torch.Tensor:
|
20 |
-
if attr.device != torch.device(
|
21 |
-
attr = attr.to(torch.device(
|
22 |
setattr(self, name, attr)
|
23 |
|
24 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
|
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
6 |
from functools import partial
|
7 |
+
import devicetorch
|
8 |
|
9 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
10 |
|
11 |
+
device = devicetorch.get(torch)
|
12 |
|
13 |
class PLMSSampler(object):
|
14 |
def __init__(self, model, schedule="linear", **kwargs):
|
|
|
19 |
|
20 |
def register_buffer(self, name, attr):
|
21 |
if type(attr) == torch.Tensor:
|
22 |
+
if attr.device != torch.device(device):
|
23 |
+
attr = attr.to(torch.device(device))
|
24 |
setattr(self, name, attr)
|
25 |
|
26 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
stable_diffusion/notebook_helpers.py
CHANGED
@@ -14,7 +14,9 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
|
14 |
from ldm.util import ismap
|
15 |
import time
|
16 |
from omegaconf import OmegaConf
|
|
|
17 |
|
|
|
18 |
|
19 |
def download_models(mode):
|
20 |
|
@@ -44,7 +46,7 @@ def load_model_from_config(config, ckpt):
|
|
44 |
sd = pl_sd["state_dict"]
|
45 |
model = instantiate_from_config(config.model)
|
46 |
m, u = model.load_state_dict(sd, strict=False)
|
47 |
-
model.
|
48 |
model.eval()
|
49 |
return {"model": model}, global_step
|
50 |
|
@@ -117,7 +119,7 @@ def get_cond(mode, selected_path):
|
|
117 |
c = rearrange(c, '1 c h w -> 1 h w c')
|
118 |
c = 2. * c - 1.
|
119 |
|
120 |
-
c = c.to(torch.device(
|
121 |
example["LR_image"] = c
|
122 |
example["image"] = c_up
|
123 |
|
@@ -267,4 +269,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
|
|
267 |
log["sample"] = x_sample
|
268 |
log["time"] = t1 - t0
|
269 |
|
270 |
-
return log
|
|
|
14 |
from ldm.util import ismap
|
15 |
import time
|
16 |
from omegaconf import OmegaConf
|
17 |
+
import devicetorch
|
18 |
|
19 |
+
device = devicetorch.get(torch)
|
20 |
|
21 |
def download_models(mode):
|
22 |
|
|
|
46 |
sd = pl_sd["state_dict"]
|
47 |
model = instantiate_from_config(config.model)
|
48 |
m, u = model.load_state_dict(sd, strict=False)
|
49 |
+
model.to(device)
|
50 |
model.eval()
|
51 |
return {"model": model}, global_step
|
52 |
|
|
|
119 |
c = rearrange(c, '1 c h w -> 1 h w c')
|
120 |
c = 2. * c - 1.
|
121 |
|
122 |
+
c = c.to(torch.device(device))
|
123 |
example["LR_image"] = c
|
124 |
example["image"] = c_up
|
125 |
|
|
|
269 |
log["sample"] = x_sample
|
270 |
log["time"] = t1 - t0
|
271 |
|
272 |
+
return log
|