File size: 2,476 Bytes
35575bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from pathlib import Path
from re import S
from typing import List, Union

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from diffusers.loaders.lora import StableDiffusionXLLoraLoaderMixin
from torchvision.datasets.utils import download_url


class LightningMixin:
    """Toggle SDXL-Lightning (8-step LoRA) on a Stable Diffusion XL pipeline.

    Call ``configure_sdxl_lightning(pipe)`` once, then switch modes with
    ``enable_sdxl_lightning()`` / ``disable_sdxl_lightning()``.

    Enabling does two things:
    - it installs an ``EulerDiscreteScheduler`` with
      ``timestep_spacing="trailing"``;
    - it adds the Lightning LoRA to the currently active adapters.

    Disabling restores the scheduler captured at configuration time and
    removes only the Lightning LoRA from the active adapters.
    """

    LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors"
    # Adapter name the Lightning LoRA is registered under on the pipeline.
    LORA_8_STEP_ADAPTER = "8step_lora"

    # State filled in by configure_sdxl_lightning(). The double underscores
    # name-mangle these to _LightningMixin__* so they cannot clash with
    # attributes of the class this mixin is combined with.
    __scheduler_old = None  # pipeline scheduler captured at configuration time
    __pipe: StableDiffusionXLPipeline = None  # pipeline being controlled
    __scheduler = None  # Lightning scheduler installed by enable_sdxl_lightning()

    def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline):
        """Load the Lightning LoRA into *pipe* and prepare it, without enabling it.

        Steps:
        1. Fetch the LoRA weights into ``~/.cache/lora_8_step.safetensors``.
        2. Load them under ``LORA_8_STEP_ADAPTER``.
        3. Build the Lightning scheduler from the pipeline's current
           scheduler config.
        4. Remember the current scheduler so ``disable_sdxl_lightning`` can
           restore it.

        Args:
            pipe: The SDXL pipeline to control.
        """
        lora_path = Path.home() / ".cache" / "lora_8_step.safetensors"

        download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name)

        pipe.load_lora_weights(str(lora_path), adapter_name=self.LORA_8_STEP_ADAPTER)
        # Activate an empty adapter set so the Lightning LoRA stays off until
        # enable_sdxl_lightning() is called.
        # NOTE(review): this presumably also deactivates any adapters that
        # were active before configuration — confirm that is intended.
        pipe.set_adapters([])

        self.__scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.__scheduler_old = pipe.scheduler
        self.__pipe = pipe

    def enable_sdxl_lightning(self):
        """Switch the configured pipeline into Lightning mode.

        Installs the Lightning scheduler. Activates the Lightning LoRA in
        addition to the adapters that are already active; each of those keeps
        the weight it currently has. Calling this again while already enabled
        does not add the LoRA a second time.

        Returns:
            dict: ``{"guidance_scale": 0, "num_inference_steps": 8}``.
            Presumably these are meant to be merged into the pipeline call's
            keyword arguments.

        Raises:
            RuntimeError: If ``configure_sdxl_lightning`` has not been called.
        """
        pipe = self.__require_pipe()
        pipe.scheduler = self.__scheduler

        # Copy first: get_active_adapters() may return a list owned by the
        # underlying tuner layers, which must not be mutated in place.
        current = list(pipe.get_active_adapters())
        if self.LORA_8_STEP_ADAPTER not in current:
            current.append(self.LORA_8_STEP_ADAPTER)

        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)

        return {"guidance_scale": 0, "num_inference_steps": 8}

    def disable_sdxl_lightning(self):
        """Switch the configured pipeline back out of Lightning mode.

        Restores the scheduler captured by ``configure_sdxl_lightning``.
        Deactivates only the Lightning LoRA; every other active adapter stays
        active at its current weight.

        Raises:
            RuntimeError: If ``configure_sdxl_lightning`` has not been called.
        """
        pipe = self.__require_pipe()
        pipe.scheduler = self.__scheduler_old

        current = [
            adapter
            for adapter in pipe.get_active_adapters()
            if adapter != self.LORA_8_STEP_ADAPTER
        ]

        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)

    def __require_pipe(self) -> StableDiffusionXLPipeline:
        """Return the configured pipeline; raise clearly if there is none."""
        if self.__pipe is None:
            raise RuntimeError(
                "configure_sdxl_lightning(pipe) must be called before "
                "enabling or disabling SDXL-Lightning"
            )
        return self.__pipe

    def __find_adapter_weights(self, names: List[str]) -> List[float]:
        """Return the current weight of each adapter in *names*, in order.

        Each weight is recovered from the UNet's tuner layers as
        ``scaling * r / lora_alpha``.
        NOTE(review): this inverts peft's LoRA convention
        ``scaling = weight * lora_alpha / r`` — confirm it against the
        installed peft version.

        If several layers know an adapter, the last one encountered wins.
        Adapters that no layer knows default to 1.0.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        wanted = set(names)
        found = {}
        # One pass over the UNet instead of one full scan per adapter name.
        for module in self.__pipe.unet.modules():
            if not isinstance(module, BaseTunerLayer):
                continue
            for adapter_name in wanted.intersection(module.scaling):
                found[adapter_name] = (
                    module.scaling[adapter_name]
                    * module.r[adapter_name]
                    / module.lora_alpha[adapter_name]
                )

        return [found.get(adapter_name, 1.0) for adapter_name in names]