diff --git a/README.md b/README.md
index 4f7c92f..e386624 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
+# THIS IS A FORK
+
+Forked from https://github.com/crowsonkb/k-diffusion
+
+Changes:
+
+1. Add DPM++ 2M sampling fix by @hallatore https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
+2. Add MPS fix for MacOS by @brkirch https://github.com/brkirch/k-diffusion
+
 # k-diffusion
 
 An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
diff --git a/k_diffusion/external.py b/k_diffusion/external.py
index 79b51ce..b41d0eb 100644
--- a/k_diffusion/external.py
+++ b/k_diffusion/external.py
@@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
 
     def t_to_sigma(self, t):
         t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t - low_idx if t.device.type == 'mps' else t.frac()
         log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
         return log_sigma.exp()
 
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index f050f88..9f859d4 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -16,7 +16,7 @@ def append_zero(x):
 
 def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
     """Constructs the noise schedule of Karras et al. (2022)."""
-    ramp = torch.linspace(0, 1, n)
+    ramp = torch.linspace(0, 1, n, device=device)
     min_inv_rho = sigma_min ** (1 / rho)
     max_inv_rho = sigma_max ** (1 / rho)
     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
@@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
 
         for i in range(len(orders)):
            eps_cache = {}
-            t, t_next = ts[i], ts[i + 1]
+
+            # MacOS fix
+            if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+                t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
+            else:
+                t, t_next = ts[i], ts[i + 1]
+
             if eta:
                 sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                 t_next_ = torch.minimum(t_end, self.t(sd))
@@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, 
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # MacOS fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # MacOS fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+
+    # MacOS fix
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        t_fn = lambda sigma: sigma.detach().clone().log().neg()
+    else:
+        t_fn = lambda sigma: sigma.log().neg()
+
     old_denoised = None
 
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
         h = t_next - t
+
+        t_min = min(sigma_fn(t_next), sigma_fn(t))
+        t_max = max(sigma_fn(t_next), sigma_fn(t))
+
         if old_denoised is None or sigmas[i + 1] == 0:
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+            x = (t_min / t_max) * x - (-h).expm1() * denoised
         else:
             h_last = t - t_fn(sigmas[i - 1])
-            r = h_last / h
+
+            h_min = min(h_last, h)
+            h_max = max(h_last, h)
+            r = h_max / h_min
+
+            h_d = (h_max + h_min) / 2
             denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
-            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+            x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
+
         old_denoised = denoised
     return x
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
index 9afedb9..ce6014b 100644
--- a/k_diffusion/utils.py
+++ b/k_diffusion/utils.py
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
     dims_to_append = target_dims - x.ndim
     if dims_to_append < 0:
         raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    return x[(...,) + (None,) * dims_to_append]
+    expanded = x[(...,) + (None,) * dims_to_append]
+    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
+    # https://github.com/pytorch/pytorch/issues/84364
+    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
 
 
 def n_params(module):
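For reference, a minimal sketch of how the patched functions fit together on an Apple Silicon machine. The `DummyDenoiser` below is a hypothetical stand-in that just predicts zeros (a real setup would wrap a trained model with one of the `k_diffusion.external` denoiser wrappers), and the image shape and sigma range are arbitrary:

```python
import torch

from k_diffusion import sampling


class DummyDenoiser(torch.nn.Module):
    """Hypothetical stand-in: the k-diffusion samplers only need a callable
    with the signature denoised = model(x, sigma, **extra_args)."""

    def forward(self, x, sigma, **kwargs):
        return torch.zeros_like(x)


device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = DummyDenoiser().to(device)

# The ramp is now created directly on the target device (first sampling.py
# hunk above), so the schedule lives on MPS without an extra host copy.
sigmas = sampling.get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=80., device=device)

# Start from pure noise scaled to the largest sigma, as the samplers expect.
x = torch.randn(1, 3, 64, 64, device=device) * sigmas[0]

# sample_dpmpp_2m applies both changes internally: the MPS-safe t_fn on
# macOS and the @hallatore step correction (t_min/t_max, h_d).
out = sampling.sample_dpmpp_2m(model, x, sigmas)
```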