import random
from collections import OrderedDict
from typing import Dict, Optional, Tuple

from diffusers import UNet2DConditionModel
from transformers import CLIPTextConfig

from optimum.exporters.onnx.model_configs import VisionOnnxConfig
from optimum.exporters.openvino import main_export
from optimum.utils.input_generators import (
    DEFAULT_DUMMY_SHAPES,
    DummyInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
    DummyVisionInputGenerator,
)
from optimum.utils.normalized_config import NormalizedConfig
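
# Export the UNet of a Latent Consistency Model (LCM) through optimum with a
# custom ONNX config. The stock UNet export path does not declare the
# LCM-specific `timestep_cond` input (the guidance-scale embedding enabled by
# `time_cond_proj_dim`), so a custom dummy-input generator and export config
# are registered below to handle it explicitly.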

class CustomDummyTimestepInputGenerator(DummyInputGenerator):
    """
    Generates the dummy timestep-related inputs for the UNet export: `timestep`,
    the LCM-specific `timestep_cond`, and the SDXL-style `text_embeds` /
    `time_ids` added conditions. Modeled on optimum's DummyTimestepInputGenerator.
    """

    SUPPORTED_INPUT_NAMES = (
        "timestep",
        "timestep_cond",
        "text_embeds",
        "time_ids",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        time_cond_proj_dim: int = 256,
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        self.task = task
        # For the UNet, `vocab_size` is remapped to `norm_num_groups` (see
        # NORMALIZED_CONFIG_CLASS below); it only serves as an upper bound for
        # the random dummy values.
        self.vocab_size = normalized_config.vocab_size
        self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
        self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
        if random_batch_size_range:
            low, high = random_batch_size_range
            self.batch_size = random.randint(low, high)
        else:
            self.batch_size = batch_size
        # Fall back to the default dimension when the model config does not
        # define `time_cond_proj_dim`.
        self.time_cond_proj_dim = normalized_config.get("time_cond_proj_dim", time_cond_proj_dim)

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        shape = [self.batch_size]

        if input_name == "timestep":
            return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)

        if input_name == "timestep_cond":
            # LCM guidance-scale embedding: (batch_size, time_cond_proj_dim).
            shape.append(self.time_cond_proj_dim)
            return self.random_float_tensor(shape, min_value=-1.0, max_value=1.0, framework=framework, dtype=float_dtype)

        # "text_embeds" -> (batch_size, projection_dim); "time_ids" -> (batch_size, 5 or 6).
        shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else self.time_ids)
        return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
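
# ONNX export config for the LCM UNet. It follows the shape of optimum's
# built-in UNet config but adds `timestep_cond` to the declared inputs and
# wires in the custom dummy-input generator above.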
class LCMUNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    DEFAULT_ONNX_OPSET = 14

    # `vocab_size` has no meaning for a UNet, so it is pointed at
    # `norm_num_groups` to give the dummy generators a usable bound.
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        CustomDummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = OrderedDict(
            {
                "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
                "timestep": {0: "steps"},
                "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
                # LCM guidance-scale embedding, absent from the stock UNet inputs.
                "timestep_cond": {0: "batch_size"},
            }
        )

        # SDXL-style UNets take extra conditioning tensors.
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs["text_embeds"] = {0: "batch_size"}
            common_inputs["time_ids"] = {0: "batch_size"}

        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "out_sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
        }

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        # The UNet's forward returns `sample`; rename the output so it does not
        # clash with the `sample` input.
        return {
            "sample": "out_sample",
        }

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
        # The seq2seq text generator returns a tuple; the UNet only needs the
        # hidden states.
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

        # SDXL-style conditioning is passed to the UNet nested under
        # `added_cond_kwargs` rather than as top-level arguments.
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }

        return dummy_inputs

    def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
        # Skip the default reordering against the model's forward signature:
        # `text_embeds` and `time_ids` are nested under `added_cond_kwargs`,
        # do not appear in the signature, and would be dropped.
        return self.inputs

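
# Build the patched UNet config and run the export.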
model_id = "SimianLuo/LCM_Dreamshaper_v7"

text_encoder_config = CLIPTextConfig.from_pretrained(model_id, subfolder="text_encoder")
unet_config = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").config

# The dummy generators expect these two attributes, which the UNet config does
# not carry on its own.
unet_config.text_encoder_projection_dim = text_encoder_config.projection_dim
unet_config.requires_aesthetics_score = False

# "semantic-segmentation" is the task optimum associates with the UNet
# sub-model of a diffusion pipeline (image-like inputs and outputs).
custom_onnx_configs = {
    "unet": LCMUNetOnnxConfig(config=unet_config, task="semantic-segmentation"),
}

main_export(
    model_name_or_path=model_id,
    output="./",
    task="stable-diffusion",
    fp16=False,
    int8=False,
    custom_onnx_configs=custom_onnx_configs,
)
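
# Optional sanity check (a sketch, assuming the installed optimum-intel ships
# `OVLatentConsistencyModelPipeline`, as recent versions do): reload the export
# and run a short LCM schedule. Prompt and step count are illustrative only.
from optimum.intel import OVLatentConsistencyModelPipeline

pipe = OVLatentConsistencyModelPipeline.from_pretrained("./")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=4).images[0]
image.save("lcm_sanity_check.png")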