import os
from typing import Type, Callable, TypeVar, Dict, Any
import torch
import diffusers
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection
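
# Helpers for SD.Next's Olive ONNX optimization path: model loaders, dummy
# conversion inputs and latency data loaders for the text encoder(s), UNet and
# VAE submodels. Configuration is shared through SDNEXT_OLIVE_* environment
# variables (see OliveOptimizerConfig below).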


class ENVStore:
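    """Typed attribute store backed by environment variables.

    Attribute reads and writes are proxied to ``SDNEXT_OLIVE_<name>`` variables,
    with (de)serialization picked from the attribute's type annotation on the
    subclass. Attributes that were never set read back as ``None``; deleting an
    attribute removes the variable.
    """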
    __DESERIALIZER: Dict[Type, Callable[[str,], Any]] = {
        bool: lambda x: bool(int(x)),
        int: int,
        str: lambda x: x,
    }
    __SERIALIZER: Dict[Type, Callable[[Any,], str]] = {
        bool: lambda x: str(int(x)),
        int: str,
        str: lambda x: x,
    }

    def __getattr__(self, name: str):
        value = os.environ.get(f"SDNEXT_OLIVE_{name}", None)
        if value is None:
            return
        ty = self.__class__.__annotations__[name]
        deserialize = self.__DESERIALIZER[ty]
        return deserialize(value)

    def __setattr__(self, name: str, value) -> None:
        if name not in self.__class__.__annotations__:
            return
        ty = self.__class__.__annotations__[name]
        serialize = self.__SERIALIZER[ty]
        os.environ[f"SDNEXT_OLIVE_{name}"] = serialize(value)

    def __delattr__(self, name: str) -> None:
        if name not in self.__class__.__annotations__:
            return
        key = f"SDNEXT_OLIVE_{name}"
        if key not in os.environ:
            return
        os.environ.pop(key)


class OliveOptimizerConfig(ENVStore):
    from_diffusers_cache: bool

    is_sdxl: bool

    vae: str
    vae_sdxl_fp16_fix: bool

    width: int
    height: int
    batch_size: int

    cross_attention_dim: int
    time_ids_size: int


config = OliveOptimizerConfig()
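# Example: `config.width = 512` stores SDNEXT_OLIVE_width="512" in the process
# environment, reading `config.width` parses it back to an int, and attributes
# that were never assigned read back as None.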


def get_variant():
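    """Map opts.diffusers_model_load_variant to a diffusers `variant` argument:
    'default' follows the device dtype (fp16 -> 'fp16', otherwise no variant),
    'fp32' forces no variant, anything else is passed through unchanged."""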
    from modules.shared import opts

    if opts.diffusers_model_load_variant == 'default':
        from modules import devices

        if devices.dtype == torch.float16:
            return 'fp16'

        return None
    elif opts.diffusers_model_load_variant == 'fp32':
        return None
    else:
        return opts.diffusers_model_load_variant


def get_loader_arguments(no_variant: bool = False):
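    """Shared from_pretrained kwargs: when loading from the local diffusers
    cache, set cache_dir and (unless no_variant is requested) the weight variant."""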
    kwargs = {}

    if config.from_diffusers_cache:
        from modules.shared import opts
        kwargs["cache_dir"] = opts.diffusers_dir
        if not no_variant:
            kwargs["variant"] = get_variant()

    return kwargs


T = TypeVar("T")
def from_pretrained(cls: Type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
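    """Load `cls` from `pretrained_model_name_or_path` with the shared loader
    arguments; a path ending in ``.onnx`` is instead loaded as
    diffusers.OnnxRuntimeModel from its parent directory."""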
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if pretrained_model_name_or_path.endswith(".onnx"):
        cls = diffusers.OnnxRuntimeModel
        pretrained_model_name_or_path = os.path.dirname(pretrained_model_name_or_path)
    return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs, **get_loader_arguments(no_variant))


# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------


# Helper latency-only dataloader that creates random tensors with no label
class RandomDataLoader:
    def __init__(self, create_inputs_func, batchsize, torch_dtype):
        self.create_inputs_func = create_inputs_func
        self.batchsize = batchsize
        self.torch_dtype = torch_dtype

    def __getitem__(self, idx):
        label = None
        return self.create_inputs_func(self.batchsize, self.torch_dtype), label

# -----------------------------------------------------------------------------
# TEXT ENCODER
# -----------------------------------------------------------------------------


def text_encoder_inputs(batchsize, torch_dtype):
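    # CLIP text encoders take a fixed 77-token sequence; for SDXL the inputs are
    # passed as a dict so that hidden states are returned as well.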
    input_ids = torch.zeros((config.batch_size, 77), dtype=torch_dtype)
    return {
        "input_ids": input_ids,
        "output_hidden_states": True,
    } if config.is_sdxl else input_ids


def text_encoder_load(model_name):
    model = from_pretrained(CLIPTextModel, model_name, subfolder="text_encoder")
    return model


def text_encoder_conversion_inputs(model):
    return text_encoder_inputs(1, torch.int32)


def text_encoder_data_loader(data_dir, batchsize, *_, **__):
    return RandomDataLoader(text_encoder_inputs, config.batch_size, torch.int32)


# -----------------------------------------------------------------------------
# TEXT ENCODER 2
# -----------------------------------------------------------------------------


def text_encoder_2_inputs(batchsize, torch_dtype):
    return {
        "input_ids": torch.zeros((config.batch_size, 77), dtype=torch_dtype),
        "output_hidden_states": True,
    }


def text_encoder_2_load(model_name):
    model = from_pretrained(CLIPTextModelWithProjection, model_name, subfolder="text_encoder_2")
    return model


def text_encoder_2_conversion_inputs(model):
    return text_encoder_2_inputs(1, torch.int64)


def text_encoder_2_data_loader(data_dir, batchsize, *_, **__):
    return RandomDataLoader(text_encoder_2_inputs, config.batch_size, torch.int64)


# -----------------------------------------------------------------------------
# UNET
# -----------------------------------------------------------------------------


def unet_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
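    # Latent inputs are 1/8 of the configured pixel resolution. For SDXL the
    # batch is doubled (2 * batch_size) to cover the concatenated
    # conditional/unconditional halves used by classifier-free guidance.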
    if config.is_sdxl:
        inputs = {
            "sample": torch.rand((2 * config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
            "timestep": torch.rand((1,), dtype=torch_dtype),
            "encoder_hidden_states": torch.rand((2 * config.batch_size, 77, config.cross_attention_dim), dtype=torch_dtype),
        }

        if is_conversion_inputs:
            inputs["additional_inputs"] = {
                "added_cond_kwargs": {
                    "text_embeds": torch.rand((2 * config.batch_size, 1280), dtype=torch_dtype),
                    "time_ids": torch.rand((2 * config.batch_size, config.time_ids_size), dtype=torch_dtype),
                }
            }
        else:
            inputs["text_embeds"] = torch.rand((2 * config.batch_size, 1280), dtype=torch_dtype)
            inputs["time_ids"] = torch.rand((2 * config.batch_size, config.time_ids_size), dtype=torch_dtype)
    else:
        inputs = {
            "sample": torch.rand((config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
            "timestep": torch.rand((config.batch_size,), dtype=torch_dtype),
            "encoder_hidden_states": torch.rand((config.batch_size, 77, config.cross_attention_dim), dtype=torch_dtype),
        }

        # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs
        kwargs = {
            "return_dict": False,
        }
        if is_conversion_inputs:
            inputs["additional_inputs"] = {
                **kwargs,
                "added_cond_kwargs": {
                    "text_embeds": torch.rand((1, 1280), dtype=torch_dtype),
                    "time_ids": torch.rand((1, 5), dtype=torch_dtype),
                },
            }
        else:
            inputs.update(kwargs)
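            # The "onnx::Concat_4" / "onnx::Shape_5" keys match the auto-generated
            # input names the exported graph uses for the extra condition tensors.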
            inputs["onnx::Concat_4"] = torch.rand((1, 1280), dtype=torch_dtype)
            inputs["onnx::Shape_5"] = torch.rand((1, 5), dtype=torch_dtype)

    return inputs


def unet_load(model_name):
    model = from_pretrained(diffusers.UNet2DConditionModel, model_name, subfolder="unet")
    return model


def unet_conversion_inputs(model):
    return tuple(unet_inputs(1, torch.float32, True).values())


def unet_data_loader(data_dir, batchsize, *_, **__):
    return RandomDataLoader(unet_inputs, config.batch_size, torch.float16)


# -----------------------------------------------------------------------------
# VAE ENCODER
# -----------------------------------------------------------------------------


def vae_encoder_inputs(batchsize, torch_dtype):
    return {
        "sample": torch.rand((config.batch_size, 3, config.height, config.width), dtype=torch_dtype),
        "return_dict": False,
    }


def vae_encoder_load(model_name):
    subfolder = "vae_encoder" if os.path.isdir(os.path.join(model_name, "vae_encoder")) else "vae"

    if config.vae_sdxl_fp16_fix:
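        # Swap in madebyollin/sdxl-vae-fp16-fix, an SDXL VAE finetuned to stay
        # finite when run in float16.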
        model_name = "madebyollin/sdxl-vae-fp16-fix"
        subfolder = ""

    if config.vae is None:
        model = from_pretrained(diffusers.AutoencoderKL, model_name, subfolder=subfolder, no_variant=config.vae_sdxl_fp16_fix)
    else:
        model = diffusers.AutoencoderKL.from_single_file(config.vae)

    # Export only the encoder path: forward() is redirected to encode(), sampling
    # from the returned latent distribution.
    model.forward = lambda sample, return_dict: model.encode(sample, return_dict)[0].sample()

    return model


def vae_encoder_conversion_inputs(model):
    return tuple(vae_encoder_inputs(1, torch.float32).values())


def vae_encoder_data_loader(data_dir, batchsize, *_, **__):
    return RandomDataLoader(vae_encoder_inputs, config.batch_size, torch.float16)


# -----------------------------------------------------------------------------
# VAE DECODER
# -----------------------------------------------------------------------------


def vae_decoder_inputs(batchsize, torch_dtype):
    return {
        "latent_sample": torch.rand((config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
        "return_dict": False,
    }


def vae_decoder_load(model_name):
    subfolder = "vae_decoder" if os.path.isdir(os.path.join(model_name, "vae_decoder")) else "vae"

    if config.vae_sdxl_fp16_fix:
        model_name = "madebyollin/sdxl-vae-fp16-fix"
        subfolder = ""

    if config.vae is None:
        model = from_pretrained(diffusers.AutoencoderKL, model_name, subfolder=subfolder, no_variant=config.vae_sdxl_fp16_fix)
    else:
        model = diffusers.AutoencoderKL.from_single_file(config.vae)

    # Export only the decoder path.
    model.forward = model.decode

    return model


def vae_decoder_conversion_inputs(model):
    return tuple(vae_decoder_inputs(1, torch.float32).values())


def vae_decoder_data_loader(data_dir, batchsize, *_, **__):
    return RandomDataLoader(vae_decoder_inputs, config.batch_size, torch.float16)