commit e47d403
diff --git a/README.md b/README.md
index 4f7c92f..e386624 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
+# THIS IS A FORK
+
+Forked from https://github.com/crowsonkb/k-diffusion
+
+Changes:
+
+1. Add the DPM++ 2M sampling fix by @hallatore: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
+2. Add the MPS fixes for macOS by @brkirch: https://github.com/brkirch/k-diffusion
+
 # k-diffusion
 An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
diff --git a/k_diffusion/external.py b/k_diffusion/external.py
index 79b51ce..b41d0eb 100644
--- a/k_diffusion/external.py
+++ b/k_diffusion/external.py
@@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
         return t.view(sigma.shape)
 
     def t_to_sigma(self, t):
         t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t - low_idx if t.device.type == 'mps' else t.frac()
         log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
         return log_sigma.exp()
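Aside (editor's note, not part of the patch): `t.frac()` has produced incorrect results on the MPS backend, and for the non-negative timesteps used here `t - t.floor()` gives the same interpolation weight. A minimal sketch of the equivalence, with made-up values:

```python
import torch

# t_to_sigma linearly interpolates log-sigmas between the two nearest
# integer timesteps; w is the interpolation weight in [0, 1).
t = torch.tensor([3.25, 7.75])
low_idx = t.floor().long()   # tensor([3, 7])
w_frac = t.frac()            # tensor([0.2500, 0.7500])
w_sub = t - low_idx          # same values, computed without frac()
assert torch.allclose(w_frac, w_sub)
```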
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index f050f88..9f859d4 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -16,7 +16,7 @@ def append_zero(x):
 
 def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
     """Constructs the noise schedule of Karras et al. (2022)."""
-    ramp = torch.linspace(0, 1, n)
+    ramp = torch.linspace(0, 1, n, device=device)
     min_inv_rho = sigma_min ** (1 / rho)
     max_inv_rho = sigma_max ** (1 / rho)
     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
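For reference, this schedule just interpolates linearly in sigma**(1/rho) space; the change only creates `ramp` directly on the target device instead of on the CPU. A runnable sketch with illustrative values (`karras_sigmas` is a hypothetical standalone name; the real function also appends a trailing zero via `append_zero`):

```python
import torch

def karras_sigmas(n, sigma_min, sigma_max, rho=7., device='cpu'):
    # Interpolate in sigma**(1/rho) space, then raise back to the rho power.
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(5, 0.1, 10.0))  # approx. [10.00, 4.07, 1.45, 0.43, 0.10], decreasing
```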
@@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
 
         for i in range(len(orders)):
             eps_cache = {}
-            t, t_next = ts[i], ts[i + 1]
+
+            # macOS (MPS) fix
+            if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+                t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
+            else:
+                t, t_next = ts[i], ts[i + 1]
+
             if eta:
                 sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                 t_next_ = torch.minimum(t_end, self.t(sd))
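A note on the pattern (my reading, not stated in the patch): indexing into a tensor returns a view, and views have triggered wrong values in some MPS kernels, so `detach().clone()` materializes an independent copy with the same values. A tiny illustration:

```python
import torch

ts = torch.linspace(1.0, 0.0, 5)
t = ts[2].detach().clone()            # independent copy, no autograd history
assert t.item() == ts[2].item()       # same value as the view ts[2]
assert t.data_ptr() != ts.data_ptr()  # but backed by its own storage
```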
@@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # macOS (MPS) fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
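These samplers step in the variable t = -log(sigma), so sigma = exp(-t); the MPS branch only inserts `detach().clone()` before the log. The same guarded `t_fn` is repeated in `sample_dpmpp_sde` and `sample_dpmpp_2m` below. A worked round trip with an arbitrary sigma:

```python
import torch

sigma_fn = lambda t: t.neg().exp()      # sigma = exp(-t)
t_fn = lambda sigma: sigma.log().neg()  # t = -log(sigma)

sigma = torch.tensor(0.5)
t = t_fn(sigma)                         # -log(0.5) is approx. 0.6931
assert torch.allclose(sigma_fn(t), sigma)
```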
@@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # macOS (MPS) fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # macOS (MPS) fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
+
     old_denoised = None
 
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
         h = t_next - t
+
+        t_min = min(sigma_fn(t_next), sigma_fn(t))
+        t_max = max(sigma_fn(t_next), sigma_fn(t))
+
         if old_denoised is None or sigmas[i + 1] == 0:
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+            x = (t_min / t_max) * x - (-h).expm1() * denoised
         else:
             h_last = t - t_fn(sigmas[i - 1])
-            r = h_last / h
+
+            h_min = min(h_last, h)
+            h_max = max(h_last, h)
+            r = h_max / h_min
+
+            h_d = (h_max + h_min) / 2
             denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+            x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
+
         old_denoised = denoised
     return x
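To make the @hallatore change concrete: upstream uses the raw ratio r = h_last / h and steps with (-h).expm1(), while the patch orders the two step sizes, uses r = h_max / h_min >= 1, and steps with the averaged size h_d. (Despite their names, t_min and t_max above hold sigma values, so t_min / t_max equals the upstream sigma ratio on a decreasing schedule.) A runnable toy with made-up step sizes:

```python
import torch

# Stand-ins for the loop state in sample_dpmpp_2m above; values are arbitrary.
h_last, h = torch.tensor(0.30), torch.tensor(0.45)

r_upstream = h_last / h                  # approx. 0.667 (can be < 1)
h_min = torch.minimum(h_last, h)
h_max = torch.maximum(h_last, h)
r_patched = h_max / h_min                # 1.5 (always >= 1)
h_d = (h_max + h_min) / 2                # 0.375, replaces h inside expm1()

print(float(r_upstream), float(r_patched), float(h_d))
```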
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
index 9afedb9..ce6014b 100644
--- a/k_diffusion/utils.py
+++ b/k_diffusion/utils.py
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
     dims_to_append = target_dims - x.ndim
     if dims_to_append < 0:
         raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    return x[(...,) + (None,) * dims_to_append]
+    expanded = x[(...,) + (None,) * dims_to_append]
+    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
+    # https://github.com/pytorch/pytorch/issues/84364
+    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
 
 
 def n_params(module):
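A usage sketch of the patched `append_dims` (the function body is copied from the hunk above; the call is illustrative): it right-pads a per-batch sigma tensor with singleton axes so it broadcasts against image-shaped tensors.

```python
import torch

def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    expanded = x[(...,) + (None,) * dims_to_append]
    # MPS can read inf out of the new axes; detach().clone() materializes them.
    # https://github.com/pytorch/pytorch/issues/84364
    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded

sigma = torch.ones(4)               # one sigma per batch element
print(append_dims(sigma, 4).shape)  # torch.Size([4, 1, 1, 1]) -> broadcasts over NCHW
```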