# test / modules /onnx_impl /__init__.py
# bilegentile's picture
# Upload folder using huggingface_hub
# c19ca42 verified
import os
from typing import Any, Dict, Callable, Optional
import numpy as np
import torch
import diffusers
import onnxruntime as ort
import optimum.onnxruntime
# Flipped to True by initialize_onnx() after its setup attempt (whether or not the
# attempt succeeded), so that subsequent calls return immediately.
initialized = False
# Holds the Olive workflow runner (`olive.workflows.run`) once initialize_olive() has
# loaded it; initialize_olive() resets it to None when loading fails.
run_olive_workflow = None
class DynamicSessionOptions(ort.SessionOptions):
    """Session options that remember the static-dimension config applied to them.

    Memory-pattern optimisation is switched off at construction time. The last
    config passed to :meth:`enable_static_dims` is kept on ``config`` so that
    :meth:`copy` can replay it on a fresh instance.
    """

    config: Optional[Dict] = None

    def __init__(self):
        super().__init__()
        self.enable_mem_pattern = False

    @classmethod
    def from_sess_options(cls, sess_options: ort.SessionOptions):
        """Return a ``DynamicSessionOptions`` derived from ``sess_options``.

        A ``DynamicSessionOptions`` input is copied (including its static-dim
        config); any other input yields a brand-new default instance.
        """
        if not isinstance(sess_options, DynamicSessionOptions):
            return DynamicSessionOptions()
        return sess_options.copy()

    def enable_static_dims(self, config: Dict):
        """Pin the UNet's free dimensions to the sizes described by ``config``.

        Reads ``hidden_batch_size``, ``height``, ``width``, ``is_sdxl`` and
        ``is_refiner`` from ``config``.
        """
        self.config = config
        override = self.add_free_dimension_override_by_name
        batch = config["hidden_batch_size"]

        override("unet_sample_batch", batch)
        override("unet_sample_channels", 4)
        # Sample height/width are 1/8 of the image size (presumably the VAE downscale factor).
        override("unet_sample_height", config["height"] // 8)
        override("unet_sample_width", config["width"] // 8)
        override("unet_time_batch", 1)
        override("unet_hidden_batch", batch)
        override("unet_hidden_sequence", 77)

        # Extra conditioning inputs are only pinned for an SDXL base (non-refiner) model.
        if config["is_sdxl"] and not config["is_refiner"]:
            override("unet_text_embeds_batch", batch)
            override("unet_text_embeds_size", 1280)
            override("unet_time_ids_batch", batch)
            override("unet_time_ids_size", 6)

    def copy(self):
        """Create a new instance and re-apply this instance's static-dim config, if any."""
        clone = DynamicSessionOptions()
        if self.config is not None:
            clone.enable_static_dims(self.config)
        return clone
class TorchCompatibleModule:
    """Base class offering a small torch-module-like surface.

    Exposes class-level ``device`` (CPU) and ``dtype`` (float32) defaults.
    ``to`` is left for subclasses to implement; ``type`` is a no-op that
    returns the instance itself.
    """

    device = torch.device("cpu")
    dtype = torch.float32

    def to(self, *_args, **_kwargs):
        """Must be overridden by subclasses; always raises here."""
        raise NotImplementedError()

    def type(self, *_args, **_kwargs):
        """Ignore every argument and hand back ``self`` unchanged."""
        return self
class TemporalModule(TorchCompatibleModule):
    """
    Placeholder for a model that cannot be moved to CPU.

    It only stores what is needed to load the real model later (execution
    provider, model path, session options). Calling :meth:`to` with a non-CPU
    device loads and returns an actual ``OnnxRuntimeModel``.
    """

    provider: Any
    path: str
    sess_options: ort.SessionOptions

    def __init__(self, provider: Any, path: str, sess_options: ort.SessionOptions):
        self.provider = provider
        self.path = path
        self.sess_options = sess_options

    def to(self, *args, **kwargs):
        """Load the real model when a non-CPU device is requested; otherwise return ``self``."""
        from .utils import extract_device

        target = extract_device(args, kwargs)
        if target is None or target.type == "cpu":
            return self

        from .execution_providers import TORCH_DEVICE_TO_EP

        # Use the provider mapped to the torch device type, falling back to the stored one.
        provider = TORCH_DEVICE_TO_EP.get(target.type, self.provider)
        options = DynamicSessionOptions.from_sess_options(self.sess_options)
        return OnnxRuntimeModel.load_model(self.path, provider, options)
class OnnxRuntimeModel(TorchCompatibleModule, diffusers.OnnxRuntimeModel):
    """``diffusers.OnnxRuntimeModel`` extended with torch-style ``to`` handling."""

    config = {} # dummy

    def named_modules(self): # dummy
        """Return an empty tuple; there are no submodules to enumerate."""
        return ()

    def to(self, *args, **kwargs):
        """Move the inference session to the device found in the arguments, if any.

        Returns ``self`` in every case.
        """
        from modules.onnx_impl.utils import extract_device, move_inference_session

        target = extract_device(args, kwargs)
        if target is None:
            return self

        self.device = target
        self.model = move_inference_session(self.model, target)
        return self
class VAEConfig:
    """Attribute-style access to a VAE config mapping, with built-in fallbacks.

    Attribute lookup order: the wrapped ``config`` mapping first, then
    ``DEFAULTS``. A name found in neither raises ``AttributeError`` so that
    ``hasattr`` and ``getattr(obj, name, default)`` behave normally.
    """

    DEFAULTS = { "scaling_factor": 0.18215 }
    # Sentinel distinguishing "key absent" from a stored value of None.
    _MISSING = object()

    config: Dict

    def __init__(self, config: Dict):
        self.config = config

    def __getattr__(self, key):
        """Resolve ``key`` from the wrapped config, then from ``DEFAULTS``.

        Raises:
            AttributeError: if ``key`` is in neither mapping.
        """
        # __getattr__ only runs after normal lookup failed. If "config" itself is
        # missing (instance not initialised, e.g. while being copied), reading
        # self.config below would re-enter this method forever.
        if key == "config":
            raise AttributeError(key)

        # Look in the config first, and consult DEFAULTS only when needed, so a
        # key without a default entry is still readable.
        value = self.config.get(key, VAEConfig._MISSING)
        if value is not VAEConfig._MISSING:
            return value

        try:
            return VAEConfig.DEFAULTS[key]
        except KeyError:
            raise AttributeError(f"'VAEConfig' object has no attribute '{key}'") from None
class VAE(TorchCompatibleModule):
    """Torch-flavoured facade over a pipeline's ``vae_encoder`` / ``vae_decoder``.

    Tensors are converted to NumPy, passed through the pipeline's VAE models
    one sample at a time, and converted back to tensors on the input's device.
    """

    pipeline: Any

    def __init__(self, pipeline: Any):
        self.pipeline = pipeline

    @property
    def config(self):
        """A fresh ``VAEConfig`` wrapping the decoder's config."""
        return VAEConfig(self.pipeline.vae_decoder.config)

    @property
    def device(self):
        """The device reported by the decoder."""
        return self.pipeline.vae_decoder.device

    def encode(self, sample: torch.Tensor, *_, **__):
        """Encode ``sample`` one item at a time; returns a single-element list holding the result tensor."""
        batch = sample.cpu().numpy()
        encoded = []
        for index in range(batch.shape[0]):
            single = batch[index : index + 1]
            encoded.append(self.pipeline.vae_encoder(sample=single)[0])
        merged = torch.from_numpy(np.concatenate(encoded))
        return [merged.to(sample.device)]

    def decode(self, latent_sample: torch.Tensor, *_, **__):
        """Decode ``latent_sample`` one item at a time; returns a single-element list holding the result tensor."""
        batch = latent_sample.cpu().numpy()
        decoded = []
        for index in range(batch.shape[0]):
            single = batch[index : index + 1]
            decoded.append(self.pipeline.vae_decoder(latent_sample=single)[0])
        merged = torch.from_numpy(np.concatenate(decoded))
        return [merged.to(latent_sample.device)]

    def to(self, *args, **kwargs):
        """Forward ``to`` to both encoder and decoder, store what they return, and return ``self``."""
        pipe = self.pipeline
        pipe.vae_encoder = pipe.vae_encoder.to(*args, **kwargs)
        pipe.vae_decoder = pipe.vae_decoder.to(*args, **kwargs)
        return self
def check_parameters_changed(p, refiner_enabled: bool):
    """Reload the ONNX model(s) when height, width or batch size differ from the compiled state.

    Does nothing for ``OnnxRawPipeline`` models or for models whose class name
    does not start with ``Onnx``. Otherwise, when there is no recorded compiled
    state or it differs from ``p``, the base model (and the refiner when
    ``refiner_enabled``) is unloaded and reloaded, and the new size is recorded.

    Returns:
        The current ``shared.sd_model``.
    """
    from modules import shared, sd_models

    model_class_name = shared.sd_model.__class__.__name__
    if model_class_name == "OnnxRawPipeline" or not model_class_name.startswith("Onnx"):
        return shared.sd_model

    target_height = p.height
    target_width = p.width

    state = shared.compiled_model_state
    needs_recompile = (
        state is None
        or state.height != target_height
        or state.width != target_width
        or state.batch_size != p.batch_size
    )
    if not needs_recompile:
        return shared.sd_model

    shared.log.info("Olive: Parameter change detected")
    shared.log.info("Olive: Recompiling base model")
    sd_models.unload_model_weights(op='model')
    sd_models.reload_model_weights(op='model')
    if refiner_enabled:
        shared.log.info("Olive: Recompiling refiner")
        sd_models.unload_model_weights(op='refiner')
        sd_models.reload_model_weights(op='refiner')

    # Re-read the state from `shared`: the reload above may have replaced it.
    # NOTE(review): this assumes the reload leaves a non-None compiled_model_state — confirm.
    shared.compiled_model_state.height = target_height
    shared.compiled_model_state.width = target_width
    shared.compiled_model_state.batch_size = p.batch_size
    return shared.sd_model
def preprocess_pipeline(p):
    """Run the ``preprocess`` hook of the base model and refiner, when they provide one.

    Logs a warning and returns early when the configured diffusers pipeline is
    not an ONNX one. When ``onnx_unload_base`` is set, the base model is
    unloaded while the refiner is preprocessed, then reloaded and preprocessed
    again.

    Returns:
        The current ``shared.sd_model``.
    """
    from modules import shared, sd_models

    if "ONNX" not in shared.opts.diffusers_pipeline:
        shared.log.warning(f"Unsupported pipeline for 'olive-ai' compile backend: {shared.opts.diffusers_pipeline}. You should select one of the ONNX pipelines.")
        return shared.sd_model

    if hasattr(shared.sd_model, "preprocess"):
        shared.sd_model = shared.sd_model.preprocess(p)

    if not hasattr(shared.sd_refiner, "preprocess"):
        return shared.sd_model

    if shared.opts.onnx_unload_base:
        sd_models.unload_model_weights(op='model')
    shared.sd_refiner = shared.sd_refiner.preprocess(p)
    if shared.opts.onnx_unload_base:
        # The base model was unloaded above, so bring it back and preprocess it again.
        sd_models.reload_model_weights(op='model')
        shared.sd_model = shared.sd_model.preprocess(p)

    return shared.sd_model
def ORTDiffusionModelPart_to(self, *args, **kwargs):
    """Forward ``to`` to the parent model, store whatever it returns, and return ``self``."""
    moved_parent = self.parent_model.to(*args, **kwargs)
    self.parent_model = moved_parent
    return self
def initialize_onnx():
    """Install this package's ONNX pipelines and runtime model into ``diffusers`` (once).

    Returns immediately if already initialised or if the ``onnx`` package is
    not installed. Otherwise it replaces ``diffusers``' ONNX classes with the
    local implementations, registers them in the auto-pipeline mappings, and
    patches ``to`` on optimum's ``_ORTDiffusionModelPart``. Any exception is
    logged instead of raised, and the module is marked initialised either way.
    """
    global initialized # pylint: disable=global-statement
    if initialized:
        return

    from installer import log, installed
    from modules import devices
    from modules.shared import opts

    onnx_installed = installed('onnx', quiet=True)
    if not onnx_installed:
        return

    try: # may fail on onnx import
        import onnx # pylint: disable=unused-import
        from .execution_providers import ExecutionProvider, TORCH_DEVICE_TO_EP, available_execution_providers

        # On a ROCm backend, torch's "cuda" device type maps to the ROCm execution provider.
        if devices.backend == "rocm":
            TORCH_DEVICE_TO_EP["cuda"] = ExecutionProvider.ROCm

        from .pipelines.onnx_stable_diffusion_pipeline import OnnxStableDiffusionPipeline
        from .pipelines.onnx_stable_diffusion_img2img_pipeline import OnnxStableDiffusionImg2ImgPipeline
        from .pipelines.onnx_stable_diffusion_inpaint_pipeline import OnnxStableDiffusionInpaintPipeline
        from .pipelines.onnx_stable_diffusion_upscale_pipeline import OnnxStableDiffusionUpscalePipeline
        from .pipelines.onnx_stable_diffusion_xl_pipeline import OnnxStableDiffusionXLPipeline
        from .pipelines.onnx_stable_diffusion_xl_img2img_pipeline import OnnxStableDiffusionXLImg2ImgPipeline

        OnnxRuntimeModel.__module__ = 'diffusers' # OnnxRuntimeModel Hijack.
        diffusers.OnnxRuntimeModel = OnnxRuntimeModel

        # Stable Diffusion pipelines.
        diffusers.OnnxStableDiffusionPipeline = OnnxStableDiffusionPipeline
        auto_pipeline = diffusers.pipelines.auto_pipeline
        auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion"] = OnnxStableDiffusionPipeline

        diffusers.OnnxStableDiffusionImg2ImgPipeline = OnnxStableDiffusionImg2ImgPipeline
        auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion"] = OnnxStableDiffusionImg2ImgPipeline

        diffusers.OnnxStableDiffusionInpaintPipeline = OnnxStableDiffusionInpaintPipeline
        auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["onnx-stable-diffusion"] = OnnxStableDiffusionInpaintPipeline

        diffusers.OnnxStableDiffusionUpscalePipeline = OnnxStableDiffusionUpscalePipeline

        # Stable Diffusion XL pipelines.
        diffusers.OnnxStableDiffusionXLPipeline = OnnxStableDiffusionXLPipeline
        auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion-xl"] = OnnxStableDiffusionXLPipeline

        diffusers.OnnxStableDiffusionXLImg2ImgPipeline = OnnxStableDiffusionXLImg2ImgPipeline
        auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion-xl"] = OnnxStableDiffusionXLImg2ImgPipeline

        diffusers.ORTStableDiffusionXLPipeline = OnnxStableDiffusionXLPipeline # Huggingface model compatibility
        diffusers.ORTStableDiffusionXLImg2ImgPipeline = OnnxStableDiffusionXLImg2ImgPipeline

        optimum.onnxruntime.modeling_diffusion._ORTDiffusionModelPart.to = ORTDiffusionModelPart_to # pylint: disable=protected-access

        log.debug(f'ONNX: version={ort.__version__} provider={opts.onnx_execution_provider}, available={available_execution_providers}')
    except Exception as e:
        log.error(f'ONNX failed to initialize: {e}')

    # Marked as done even after a failure, so the setup is not retried on every call.
    initialized = True
def initialize_olive():
    """Load Olive's workflow runner into the module-level ``run_olive_workflow``.

    Does nothing unless both ``olive-ai`` and ``onnx`` are installed. While
    importing, ``sys.path`` is temporarily rewritten so that the package
    directories of ``onnxruntime.transformers`` and ``olive`` come first; the
    original ``sys.path`` is always restored before returning. On any failure
    ``run_olive_workflow`` is set to ``None`` and the error is logged.
    """
    global run_olive_workflow # pylint: disable=global-statement
    from installer import installed, log
    if not installed('olive-ai', quiet=True) or not installed('onnx', quiet=True):
        return
    import sys
    import importlib
    # find_spec lives in the importlib.util submodule, which a bare
    # `import importlib` does not guarantee to have loaded.
    import importlib.util
    orig_sys_path = sys.path
    venv_dir = os.environ.get("VENV_DIR", os.path.join(os.getcwd(), 'venv'))
    try:
        spec = importlib.util.find_spec('onnxruntime.transformers')
        if spec is None:
            raise ImportError("package 'onnxruntime.transformers' was not found")
        # Drop entries containing sys.path[1] unless they are inside the venv.
        # NOTE(review): sys.path[1] is presumably the application root — confirm.
        excluded_root = sys.path[1]
        # list(...) because submodule_search_locations is not always a plain list
        # (namespace packages) and may be None for non-packages.
        candidates = list(spec.submodule_search_locations or []) + sys.path
        sys.path = [d for d in candidates if excluded_root not in d or venv_dir in d]
        from onnxruntime.transformers import convert_generation # pylint: disable=unused-import
        spec = importlib.util.find_spec('olive')
        if spec is None:
            raise ImportError("package 'olive' was not found")
        sys.path = list(spec.submodule_search_locations or []) + sys.path
        run_olive_workflow = importlib.import_module('olive.workflows').run
    except Exception as e:
        run_olive_workflow = None
        log.error(f'Olive: Failed to load olive-ai: {e}')
    finally:
        # Restore the caller's sys.path on every exit path, including
        # exceptions that are not Exception subclasses.
        sys.path = orig_sys_path
def install_olive():
    """Install ``onnx`` and ``olive-ai`` unless ``olive-ai`` is already installed.

    After installing, ``olive.workflows`` is imported as a check. A failure at
    any step is logged; on success the user is asked to restart the session.
    """
    from installer import installed, install, log

    if installed("olive-ai"):
        return

    try:
        log.info('Installing Olive')
        install('onnx', 'onnx', ignore=True)
        install('olive-ai', 'olive-ai', ignore=True)
        import olive.workflows # pylint: disable=unused-import
    except Exception as e:
        log.error(f'Olive: Failed to load olive-ai: {e}')
        return

    log.info('Olive: Please restart webui session.')