from pathlib import Path
from typing import List

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from torchvision.datasets.utils import download_url


class LightningMixin:
    """Mixin that adds ByteDance SDXL-Lightning 8-step LoRA support to a pipeline owner."""

    LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors"

    __pipe: StableDiffusionXLPipeline = None
    __scheduler = None
    __scheduler_old = None

    def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline):
        """Download the 8-step Lightning LoRA, register it, and prepare the scheduler swap."""
        lora_path = Path.home() / ".cache" / "lora_8_step.safetensors"
        download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name)
        pipe.load_lora_weights(str(lora_path), adapter_name="8step_lora")
        # Keep the adapter loaded but inactive until enable_sdxl_lightning() is called.
        pipe.set_adapters([])
        # Lightning checkpoints expect "trailing" timestep spacing.
        self.__scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.__scheduler_old = pipe.scheduler
        self.__pipe = pipe

    def enable_sdxl_lightning(self):
        """Activate the Lightning scheduler and adapter; returns pipeline-call overrides."""
        pipe = self.__pipe
        pipe.scheduler = self.__scheduler
        current = pipe.get_active_adapters()
        if "8step_lora" not in current:
            current.append("8step_lora")
        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)
        # SDXL-Lightning is distilled for guidance-free sampling in 8 steps.
        return {"guidance_scale": 0, "num_inference_steps": 8}

    def disable_sdxl_lightning(self):
        """Restore the original scheduler and deactivate the Lightning adapter."""
        pipe = self.__pipe
        pipe.scheduler = self.__scheduler_old
        current = pipe.get_active_adapters()
        current = [adapter for adapter in current if adapter != "8step_lora"]
        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)

    def __find_adapter_weights(self, names: List[str]):
        """Recover the effective weight of each named adapter from the UNet's PEFT layers."""
        pipe = self.__pipe
        model = pipe.unet
        # Imported lazily so peft stays optional until LoRA adapters are actually used.
        from peft.tuners.tuners_utils import BaseTunerLayer

        weights = []
        for adapter_name in names:
            weight = 1.0
            for module in model.modules():
                if isinstance(module, BaseTunerLayer):
                    if adapter_name in module.scaling:
                        # PEFT stores scaling = weight * lora_alpha / r, so invert it.
                        weight = (
                            module.scaling[adapter_name]
                            * module.r[adapter_name]
                            / module.lora_alpha[adapter_name]
                        )
            weights.append(weight)
        return weights
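

if __name__ == "__main__":
    # Hypothetical usage sketch (not part of the original module): shows how a class
    # that owns a StableDiffusionXLPipeline might mix in LightningMixin and toggle the
    # 8-step fast path. The checkpoint id, dtype, device, prompt, and output filename
    # below are assumptions; adjust them to your environment.
    import torch

    class LightningSDXL(LightningMixin):
        def __init__(self, model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id, torch_dtype=torch.float16, variant="fp16"
            ).to("cuda")
            self.configure_sdxl_lightning(self.pipe)

    runner = LightningSDXL()

    # Fast path: Lightning scheduler + LoRA, 8 steps, guidance disabled.
    overrides = runner.enable_sdxl_lightning()
    image = runner.pipe("an astronaut riding a horse", **overrides).images[0]
    image.save("lightning_8step.png")

    # Restore the original scheduler and adapter set for full-quality sampling.
    runner.disable_sdxl_lightning()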