import torch

from sgm.modules.diffusionmodules.discretizer import Discretization


class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas it returns, keeping only the
    trailing part of the schedule.

    params:
        discretization: the wrapped discretizer; it is called with whatever
            arguments this wrapper receives and must return a 1-D tensor of
            sigmas (expected to be ordered largest first).
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all
            sigmas are returned); smaller values keep proportionally fewer
            sigmas, but always at least one.
    """

    def __init__(self, discretization: Discretization, strength: float = 1.0):
        """Store the wrapped discretizer and strength.

        Raises AssertionError if ``strength`` is outside [0.0, 1.0]
        (note: assertions are skipped when Python runs with ``-O``).
        """
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretizer and return the pruned sigmas.

        All arguments are forwarded unchanged to the wrapped discretizer.
        The last ``max(int(strength * len(sigmas)), 1)`` entries of its
        output are returned, in their original order. Progress is printed
        to stdout.
        """
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Reverse so that the entries to keep sit at the front of the tensor.
        sigmas = torch.flip(sigmas, (0,))
        # Compute the cut-off once, from the UNPRUNED length, so that the
        # value reported below is the one actually used for slicing.
        prune_index = max(int(self.strength * len(sigmas)), 1)
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        # Restore the original ordering.
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas


class Txt2NoisyDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas it returns, dropping the
    trailing part of the schedule.

    params:
        discretization: the wrapped discretizer; it is called with whatever
            arguments this wrapper receives and must return a 1-D tensor of
            sigmas (expected to be ordered largest first).
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all
            sigmas are returned); larger values drop more entries from the
            end of the schedule.
        original_steps: optional step count used instead of the length of
            the returned sigmas when computing the prune index; the index is
            then based on ``original_steps + 1``.
    """

    def __init__(
        self, discretization: Discretization, strength: float = 0.0, original_steps=None
    ):
        """Store the wrapped discretizer, strength and optional step count.

        Raises AssertionError if ``strength`` is outside [0.0, 1.0]
        (note: assertions are skipped when Python runs with ``-O``).
        """
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretizer and return the pruned sigmas.

        All arguments are forwarded unchanged to the wrapped discretizer.
        The last ``prune_index`` entries of its output are dropped, where
        ``prune_index`` is ``int(strength * steps) - 1`` clamped to
        ``[0, steps - 1]``. Progress is printed to stdout.
        """
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Reverse so that the entries to drop sit at the front of the tensor.
        sigmas = torch.flip(sigmas, (0,))
        if self.original_steps is None:
            steps = len(sigmas)
        else:
            # NOTE(review): presumably the schedule holds one more sigma than
            # the number of sampling steps - confirm against the discretizer.
            steps = self.original_steps + 1
        # Clamp into [0, steps - 1]; strength == 0.0 yields 0 (nothing dropped).
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
        sigmas = sigmas[prune_index:]
        print("prune index:", prune_index)
        # Restore the original ordering.
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas