import os
from typing import Type, Callable, TypeVar, Dict, Any

import torch
import diffusers
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection


class ENVStore:
    """Persists typed attributes as SDNEXT_OLIVE_* environment variables.

    The annotated type on the subclass selects the (de)serializer, and the
    values survive re-imports of this script and are inherited by child
    processes.
    """

    __DESERIALIZER: Dict[Type, Callable[[str], Any]] = {
        bool: lambda x: bool(int(x)),
        int: int,
        str: lambda x: x,
    }
    __SERIALIZER: Dict[Type, Callable[[Any], str]] = {
        bool: lambda x: str(int(x)),
        int: str,
        str: lambda x: x,
    }

    def __getattr__(self, name: str):
        value = os.environ.get(f"SDNEXT_OLIVE_{name}", None)
        if value is None:
            return None
        ty = self.__class__.__annotations__[name]
        deserialize = self.__DESERIALIZER[ty]
        return deserialize(value)

    def __setattr__(self, name: str, value) -> None:
        if name not in self.__class__.__annotations__:
            return
        ty = self.__class__.__annotations__[name]
        serialize = self.__SERIALIZER[ty]
        os.environ[f"SDNEXT_OLIVE_{name}"] = serialize(value)

    def __delattr__(self, name: str) -> None:
        if name not in self.__class__.__annotations__:
            return
        key = f"SDNEXT_OLIVE_{name}"
        if key not in os.environ:
            return
        os.environ.pop(key)


class OliveOptimizerConfig(ENVStore):
    from_diffusers_cache: bool
    is_sdxl: bool
    vae: str
    vae_sdxl_fp16_fix: bool
    width: int
    height: int
    batch_size: int
    cross_attention_dim: int
    time_ids_size: int


config = OliveOptimizerConfig()


def get_variant():
    from modules.shared import opts
    if opts.diffusers_model_load_variant == 'default':
        from modules import devices
        # prefer the fp16 weight variant when the active device dtype is half precision
        if devices.dtype == torch.float16:
            return 'fp16'
        return None
    elif opts.diffusers_model_load_variant == 'fp32':
        return None
    else:
        return opts.diffusers_model_load_variant


def get_loader_arguments(no_variant: bool = False):
    kwargs = {}
    if config.from_diffusers_cache:
        from modules.shared import opts
        kwargs["cache_dir"] = opts.diffusers_dir
    if not no_variant:
        kwargs["variant"] = get_variant()
    return kwargs


T = TypeVar("T")


def from_pretrained(cls: Type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    # a path to an .onnx file means the submodel was already converted:
    # load its parent directory as an OnnxRuntimeModel instead
    if pretrained_model_name_or_path.endswith(".onnx"):
        cls = diffusers.OnnxRuntimeModel
        pretrained_model_name_or_path = os.path.dirname(pretrained_model_name_or_path)
    return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs, **get_loader_arguments(no_variant))
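
# Illustrative round-trip (a sketch, not executed anywhere in this script):
# the annotated type picks the serializer, so values written through the
# config object land in the environment as strings and come back typed.
#
#   config.batch_size = 2
#   os.environ["SDNEXT_OLIVE_batch_size"]  # -> "2"  (int serialized via str)
#   config.batch_size                      # -> 2    (deserialized via int)
#   del config.batch_size                  # unsets SDNEXT_OLIVE_batch_size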

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
# Helper latency-only dataloader that creates random tensors with no label
class RandomDataLoader:
    def __init__(self, create_inputs_func, batchsize, torch_dtype):
        self.create_input_func = create_inputs_func
        self.batchsize = batchsize
        self.torch_dtype = torch_dtype

    def __getitem__(self, idx):
        label = None
        return self.create_input_func(self.batchsize, self.torch_dtype), label


# -----------------------------------------------------------------------------
# TEXT ENCODER
# -----------------------------------------------------------------------------


def text_encoder_inputs(batchsize, torch_dtype):
    # shapes are driven by config rather than the batchsize argument
    input_ids = torch.zeros((config.batch_size, 77), dtype=torch_dtype)
    # SDXL pipelines also consume the encoder hidden states, so request them
    return {
        "input_ids": input_ids,
        "output_hidden_states": True,
    } if config.is_sdxl else input_ids


def text_encoder_load(model_name):
    model = from_pretrained(CLIPTextModel, model_name, subfolder="text_encoder")
    return model


def text_encoder_conversion_inputs(model):
    return text_encoder_inputs(1, torch.int32)


def text_encoder_data_loader(data_dir, batchsize, *_, **__):
    return RandomDataLoader(text_encoder_inputs, config.batch_size, torch.int32)


# -----------------------------------------------------------------------------
# TEXT ENCODER 2
# -----------------------------------------------------------------------------


def text_encoder_2_inputs(batchsize, torch_dtype):
    return {
        "input_ids": torch.zeros((config.batch_size, 77), dtype=torch_dtype),
        "output_hidden_states": True,
    }


def text_encoder_2_load(model_name):
    model = from_pretrained(CLIPTextModelWithProjection, model_name, subfolder="text_encoder_2")
    return model


def text_encoder_2_conversion_inputs(model):
    return text_encoder_2_inputs(1, torch.int64)


def text_encoder_2_data_loader(data_dir, batchsize, *_, **__):
    return RandomDataLoader(text_encoder_2_inputs, config.batch_size, torch.int64)


# -----------------------------------------------------------------------------
# UNET
# -----------------------------------------------------------------------------


def unet_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
    if config.is_sdxl:
        # batch is doubled: positive and negative prompts are run together
        # for classifier-free guidance
        inputs = {
            "sample": torch.rand((2 * config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
            "timestep": torch.rand((1,), dtype=torch_dtype),
            "encoder_hidden_states": torch.rand((2 * config.batch_size, 77, config.cross_attention_dim), dtype=torch_dtype),
        }
        if is_conversion_inputs:
            inputs["additional_inputs"] = {
                "added_cond_kwargs": {
                    "text_embeds": torch.rand((2 * config.batch_size, 1280), dtype=torch_dtype),
                    "time_ids": torch.rand((2 * config.batch_size, config.time_ids_size), dtype=torch_dtype),
                }
            }
        else:
            inputs["text_embeds"] = torch.rand((2 * config.batch_size, 1280), dtype=torch_dtype)
            inputs["time_ids"] = torch.rand((2 * config.batch_size, config.time_ids_size), dtype=torch_dtype)
    else:
        inputs = {
            "sample": torch.rand((config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
            "timestep": torch.rand((config.batch_size,), dtype=torch_dtype),
            "encoder_hidden_states": torch.rand((config.batch_size, 77, config.cross_attention_dim), dtype=torch_dtype),
        }

        # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs
        kwargs = {
            "return_dict": False,
        }
        if is_conversion_inputs:
            inputs["additional_inputs"] = {
                **kwargs,
                "added_cond_kwargs": {
                    "text_embeds": torch.rand((1, 1280), dtype=torch_dtype),
                    "time_ids": torch.rand((1, 5), dtype=torch_dtype),
                },
            }
        else:
            inputs.update(kwargs)
            # input names assigned by the ONNX exporter; shapes match the
            # text_embeds and time_ids tensors of the conversion branch above
            inputs["onnx::Concat_4"] = torch.rand((1, 1280), dtype=torch_dtype)
            inputs["onnx::Shape_5"] = torch.rand((1, 5), dtype=torch_dtype)

    return inputs
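
# Shape sketch (illustrative, assuming a typical SDXL config: batch_size=1,
# width=height=1024, cross_attention_dim=2048, time_ids_size=6). The
# non-conversion inputs produced above would then be:
#
#   sample                 -> (2, 4, 128, 128)   # latents, batch doubled for CFG
#   timestep               -> (1,)
#   encoder_hidden_states  -> (2, 77, 2048)
#   text_embeds            -> (2, 1280)
#   time_ids               -> (2, 6)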
inputs["onnx::Concat_4"] = torch.rand((1, 1280), dtype=torch_dtype) inputs["onnx::Shape_5"] = torch.rand((1, 5), dtype=torch_dtype) return inputs def unet_load(model_name): model = from_pretrained(diffusers.UNet2DConditionModel, model_name, subfolder="unet") return model def unet_conversion_inputs(model): return tuple(unet_inputs(1, torch.float32, True).values()) def unet_data_loader(data_dir, batchsize, *_, **__): return RandomDataLoader(unet_inputs, config.batch_size, torch.float16) # ----------------------------------------------------------------------------- # VAE ENCODER # ----------------------------------------------------------------------------- def vae_encoder_inputs(batchsize, torch_dtype): return { "sample": torch.rand((config.batch_size, 3, config.height, config.width), dtype=torch_dtype), "return_dict": False, } def vae_encoder_load(model_name): subfolder = "vae_encoder" if os.path.isdir(os.path.join(model_name, "vae_encoder")) else "vae" if config.vae_sdxl_fp16_fix: model_name = "madebyollin/sdxl-vae-fp16-fix" subfolder = "" if config.vae is None: model = from_pretrained(diffusers.AutoencoderKL, model_name, subfolder=subfolder, no_variant=config.vae_sdxl_fp16_fix) else: model = diffusers.AutoencoderKL.from_single_file(config.vae) model.forward = lambda sample, return_dict: model.encode(sample, return_dict)[0].sample() return model def vae_encoder_conversion_inputs(model): return tuple(vae_encoder_inputs(1, torch.float32).values()) def vae_encoder_data_loader(data_dir, batchsize, *_, **__): return RandomDataLoader(vae_encoder_inputs, config.batch_size, torch.float16) # ----------------------------------------------------------------------------- # VAE DECODER # ----------------------------------------------------------------------------- def vae_decoder_inputs(batchsize, torch_dtype): return { "latent_sample": torch.rand((config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype), "return_dict": False, } def vae_decoder_load(model_name): subfolder = "vae_decoder" if os.path.isdir(os.path.join(model_name, "vae_decoder")) else "vae" if config.vae_sdxl_fp16_fix: model_name = "madebyollin/sdxl-vae-fp16-fix" subfolder = "" if config.vae is None: model = from_pretrained(diffusers.AutoencoderKL, model_name, subfolder=subfolder, no_variant=config.vae_sdxl_fp16_fix) else: model = diffusers.AutoencoderKL.from_single_file(config.vae) model.forward = model.decode return model def vae_decoder_conversion_inputs(model): return tuple(vae_decoder_inputs(1, torch.float32).values()) def vae_decoder_data_loader(data_dir, batchsize, *_, **__): return RandomDataLoader(vae_decoder_inputs, config.batch_size, torch.float16)