File size: 6,180 Bytes
e47d403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
diff --git a/README.md b/README.md
index 4f7c92f..e386624 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
+# THIS IS A FORK
+
+Forked from https://github.com/crowsonkb/k-diffusion
+
+Changes:
+
+1. Add DPM++ 2M sampling fix by @hallatore https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
+2. Add MPS fix for macOS by @brkirch https://github.com/brkirch/k-diffusion
+
 # k-diffusion
 
 An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
diff --git a/k_diffusion/external.py b/k_diffusion/external.py
index 79b51ce..b41d0eb 100644
--- a/k_diffusion/external.py
+++ b/k_diffusion/external.py
@@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
 
     def t_to_sigma(self, t):
         t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t - low_idx if t.device.type == 'mps' else t.frac()
         log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
         return log_sigma.exp()
 
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index f050f88..9f859d4 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -16,7 +16,7 @@ def append_zero(x):
 
 def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
     """Constructs the noise schedule of Karras et al. (2022)."""
-    ramp = torch.linspace(0, 1, n)
+    ramp = torch.linspace(0, 1, n, device=device)
     min_inv_rho = sigma_min ** (1 / rho)
     max_inv_rho = sigma_max ** (1 / rho)
     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
@@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
 
         for i in range(len(orders)):
             eps_cache = {}
-            t, t_next = ts[i], ts[i + 1]
+
+            # MacOS fix
+            if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+                t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
+            else:
+                t, t_next = ts[i], ts[i + 1]
+
             if eta:
                 sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                 t_next_ = torch.minimum(t_end, self.t(sd))
@@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # MacOS fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # MacOS fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # MacOS fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
+
     old_denoised = None
 
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
         h = t_next - t
+        
+        t_min = min(sigma_fn(t_next), sigma_fn(t))
+        t_max = max(sigma_fn(t_next), sigma_fn(t))
+
         if old_denoised is None or sigmas[i + 1] == 0:
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+            x = (t_min / t_max) * x - (-h).expm1() * denoised
         else:
             h_last = t - t_fn(sigmas[i - 1])
-            r = h_last / h
+
+            h_min = min(h_last, h)
+            h_max = max(h_last, h)
+            r = h_max / h_min
+
+            h_d = (h_max + h_min) / 2
             denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+            x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
+
         old_denoised = denoised
     return x
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
index 9afedb9..ce6014b 100644
--- a/k_diffusion/utils.py
+++ b/k_diffusion/utils.py
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
     dims_to_append = target_dims - x.ndim
     if dims_to_append < 0:
         raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    return x[(...,) + (None,) * dims_to_append]
+    expanded = x[(...,) + (None,) * dims_to_append]
+    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
+    # https://github.com/pytorch/pytorch/issues/84364
+    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
 
 
 def n_params(module):