|
import mixdq_extension._C |
|
import inspect |
|
from typing import Any, Callable, Dict, List, Optional, Union, Tuple |
|
from collections import namedtuple |
|
import sys |
|
import os |
|
import torch |
|
|
|
|
import torch.nn.functional as F |
|
import math |
|
|
|
import torch.nn as nn |
|
|
from torch.ao.quantization import QConfig |
|
|
|
from transformers import ( |
|
CLIPImageProcessor, |
|
CLIPTextModel, |
|
CLIPTextModelWithProjection, |
|
CLIPTokenizer, |
|
CLIPVisionModelWithProjection, |
|
) |
|
|
|
from diffusers import StableDiffusionXLPipeline |
|
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
|
from diffusers.loaders import ( |
|
FromSingleFileMixin, |
|
IPAdapterMixin, |
|
StableDiffusionXLLoraLoaderMixin, |
|
TextualInversionLoaderMixin, |
|
) |
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.models.lora import adjust_lora_scale_text_encoder |
|
from diffusers.models.attention_processor import ( |
|
AttnProcessor2_0, |
|
LoRAAttnProcessor2_0, |
|
LoRAXFormersAttnProcessor, |
|
XFormersAttnProcessor, |
|
) |
|
from diffusers.utils import ( |
|
USE_PEFT_BACKEND, |
|
deprecate, |
|
is_invisible_watermark_available, |
|
is_torch_xla_available, |
|
logging, |
|
replace_example_docstring, |
|
scale_lora_layers, |
|
unscale_lora_layers, |
|
) |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput |
|
|
|
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QuantStub, DeQuantStub
|
|
|
import copy |
|
import itertools |
|
import warnings |
|
|
|
|
import torch.ao.nn.quantized as nnq |
|
from torch.ao.nn.intrinsic import _FusedModule |
|
|
|
from torch.ao.quantization.quantization_mappings import ( |
|
get_default_dynamic_quant_module_mappings, |
|
get_default_static_quant_module_mappings, |
|
get_default_static_quant_reference_module_mappings, |
|
get_default_qat_module_mappings, |
|
get_default_qconfig_propagation_list, |
|
no_observer_set, |
|
_has_special_act_post_process, |
|
_get_special_act_post_process, |
|
) |
|
from torch.ao.quantization.utils import get_qparam_dict, has_no_children_ignoring_parametrizations |
|
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper |
|
from torch.ao.quantization.qconfig import ( |
|
_add_module_to_qconfig_obs_ctr, |
|
default_dynamic_qconfig, |
|
float16_dynamic_qconfig, |
|
float_qparams_weight_only_qconfig, |
|
float_qparams_weight_only_qconfig_4bit, |
|
_activation_is_memoryless) |
|
from torch.nn.utils.parametrize import type_before_parametrizations |
|
from torch.ao.quantization.observer import _is_activation_post_process |
|
|
|
|
|
from torch.ao.quantization.observer import ( |
|
_is_activation_post_process as is_activation_post_process |
|
) |
|
|
|
if is_invisible_watermark_available(): |
|
    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
XLA_AVAILABLE = False |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
EXAMPLE_DOC_STRING = """ |
|
Examples: |
|
```py |
|
>>> import torch |
|
>>> from diffusers import StableDiffusionXLPipeline |
|
|
|
>>> pipe = StableDiffusionXLPipeline.from_pretrained( |
|
... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 |
|
... ) |
|
>>> pipe = pipe.to("cuda") |
|
|
|
>>> prompt = "a photo of an astronaut riding a horse on mars" |
|
>>> image = pipe(prompt).images[0] |
|
``` |
|
""" |
|
|
|
|
|
|
|
|
|
# Kernels from the MixDQ CUDA extension: a W8A8 GEMM with fp16 ("ohalf") output,
# and an fp16 -> int8 per-tensor quantizer.
qlinear = mixdq_extension._C.qlinear_w8_a8_ohalf
quantize_per_tensor = mixdq_extension._C.quantize_per_tensor_to_int8
|
|
|
|
|
def qconv2d( |
|
input_int, |
|
weight_int, |
|
weight_scale, |
|
input_scale, |
|
input_zp, |
|
bias=None, |
|
stride=1, |
|
padding=0, |
|
): |
|
    dilation = 1
    if padding > 0:
        return mixdq_extension._C.qconv2d_with_padding_w8_a8_ohalf(
            input_int, weight_int, weight_scale, input_scale, input_zp,
            bias, stride, padding, dilation
        )
    elif padding == 0:
        return mixdq_extension._C.qconv2d_w8_a8_ohalf(
            input_int, weight_int, weight_scale, input_scale, input_zp,
            bias, stride, padding, dilation
        )
    else:
        raise ValueError(f"padding must be an integer >= 0, got {padding}")
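# A minimal usage sketch for the extension kernels above (an illustration, not a
# tested call): activations and weights are int8, scales and zero points are fp32
# tensors, and outputs are fp16. The conv kernels expect channels_last tensors
# (see conv_cutlass below). Shapes and names here are placeholders:
#
#     x_int = quantize_per_tensor(x_fp16, act_scale, act_zp)            # int8
#     y = qlinear(x_int, w_int, w_scale, act_scale, act_zp, bias_fp16)  # fp16
#     y = qconv2d(x_int_nhwc, w_int_nhwc, w_scale, act_scale, act_zp,
#                 bias_fp16, stride=1, padding=1)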
|
|
__all__ = [ |
|
"get_default_custom_config_dict", |
|
"propagate_qconfig_", |
|
"add_quant_dequant", |
|
"prepare", |
|
"quantize", |
|
"quantize_dynamic", |
|
"prepare_qat", |
|
"quantize_qat", |
|
"convert", |
|
"swap_module", |
|
    "QuantizedLinear",
    "QuantizedConv2d",
|
] |
|
|
|
|
|
_DEFAULT_CUSTOM_CONFIG_DICT = { |
|
'float_to_observed_custom_module_class': { |
|
nn.LSTM: nn.quantizable.LSTM, |
|
nn.MultiheadAttention: nn.quantizable.MultiheadAttention, |
|
}, |
|
'observed_to_quantized_custom_module_class': { |
|
nn.quantizable.LSTM: nn.quantized.LSTM, |
|
nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention, |
|
} |
|
} |
|
|
|
# Input-channel split points for the up_blocks `conv_shortcut` convs; consumed in
# order by swap_module via the _NUM counter below.
_SPLIT = [1280, 1280, 1280, 1280, 640, 640, 640, 320, 320]
|
|
|
|
|
# Index into _SPLIT, advanced each time an up_blocks conv_shortcut is swapped.
_NUM = 0
|
|
|
|
|
def get_default_custom_config_dict(): |
|
r"""Defines the default custom config dict. |
|
""" |
|
return _DEFAULT_CUSTOM_CONFIG_DICT |
|
|
|
|
|
def _propagate_qconfig_helper(module, qconfig_dict, |
|
qconfig_parent=None, prefix='', prepare_custom_config_dict=None): |
|
r"""This is a helper function for `propagate_qconfig_` |
|
|
|
Args: |
|
module: input module |
|
qconfig_dict: dictionary that maps from name of submodule to quantization |
|
configuration |
|
qconfig_parent: quantization config of parent module, we will fallback to |
|
this config when there is no specified config for current |
|
module |
|
prefix: corresponding prefix of the current module, used as key in |
|
qconfig_dict |
|
prepare_custom_config_dict: dictionary for custom handling of modules |
|
see docs for :func:`~torch.ao.quantization.prepare_fx` |
|
|
|
Return: |
|
None, module is modified inplace with qconfig attached |
|
""" |
|
|
|
module_qconfig = qconfig_dict.get( |
|
type_before_parametrizations(module), qconfig_parent) |
|
module_qconfig = qconfig_dict.get(prefix, module_qconfig) |
|
module_qconfig = getattr(module, 'qconfig', module_qconfig) |
|
|
|
torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module) |
|
|
|
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( |
|
module_qconfig, module) |
|
module.qconfig = qconfig_with_device_check |
|
|
|
for name, child in module.named_children(): |
|
module_prefix = prefix + '.' + name if prefix else name |
|
|
|
if prepare_custom_config_dict is None or not ( |
|
name in prepare_custom_config_dict.get( |
|
"non_traceable_module_name", []) |
|
or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", []) |
|
): |
|
_propagate_qconfig_helper( |
|
child, qconfig_dict, qconfig_with_device_check, module_prefix |
|
) |
|
|
|
|
|
def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None): |
|
r"""Propagate qconfig through the module hierarchy and assign `qconfig` |
|
attribute on each leaf module |
|
|
|
Args: |
|
module: input module |
|
qconfig_dict: dictionary that maps from name or type of submodule to |
|
quantization configuration, qconfig applies to all submodules of a |
|
given module unless qconfig for the submodules are specified (when |
|
the submodule already has qconfig attribute) |
|
prepare_custom_config_dict: dictionary for custom handling of modules |
|
see docs for :func:`~torch.ao.quantization.prepare_fx` |
|
|
|
Return: |
|
None, module is modified inplace with qconfig attached |
|
""" |
|
if qconfig_dict is None: |
|
qconfig_dict = {} |
|
if prepare_custom_config_dict is None: |
|
prepare_custom_config_dict = {} |
|
_propagate_qconfig_helper( |
|
module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict) |
|
|
|
|
|
def _observer_forward_hook(self, input, output): |
|
r"""Forward hook that calls observer on the output |
|
""" |
|
return self.activation_post_process(output) |
|
|
|
|
|
def _observer_forward_pre_hook(self, input): |
|
r"""Forward pre hook that calls observer on the output |
|
""" |
|
return self.activation_post_process(input[0]) |
|
|
|
|
|
def _register_activation_post_process_hook(module, pre_hook=False): |
|
assert hasattr(module, 'activation_post_process'), \ |
|
'Expect activation_post_process attribute already attached to the module' |
|
if pre_hook: |
|
handle = module.register_forward_pre_hook( |
|
_observer_forward_pre_hook, prepend=True |
|
) |
|
else: |
|
handle = module.register_forward_hook( |
|
_observer_forward_hook, prepend=True |
|
) |
|
|
|
|
|
def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None): |
|
r"""Add observer for the leaf child of the module. |
|
|
|
This function insert observer module to all leaf child module that |
|
has a valid qconfig attribute. |
|
|
|
Args: |
|
module: input module with qconfig attributes for all the leaf modules that we want to quantize |
|
qconfig_propagation_list: a list of quantizable modules that will have observers added to them |
|
if they are leaf nodes |
|
device: parent device, if any |
|
        non_leaf_module_list: list of non-leaf modules to which we want to add observers
|
|
|
Return: |
|
None, module is modified inplace with added observer modules and forward_hooks |
|
""" |
|
if qconfig_propagation_list is None: |
|
qconfig_propagation_list = get_default_qconfig_propagation_list() |
|
|
|
if custom_module_class_mapping is None: |
|
custom_module_class_mapping = {} |
|
|
|
|
|
if device is None: |
|
devices = _get_unique_devices_(module) |
|
assert len(devices) <= 1, ( |
|
f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" |
|
) |
|
device = next(iter(devices)) if len(devices) > 0 else None |
|
|
|
def get_activation_post_process(qconfig, device, special_act_post_process=None): |
|
activation = qconfig.activation( |
|
) if special_act_post_process is None else special_act_post_process() |
|
if device is not None: |
|
activation.to(device) |
|
return activation |
|
|
|
def needs_observation(m): |
|
return hasattr(m, 'qconfig') and m.qconfig is not None |
|
|
|
def insert_activation_post_process(m, special_act_post_process=None): |
|
""" Adds an activation post process module and register |
|
a pre or post hook that calls the module |
|
""" |
|
|
|
if needs_observation(m) and not isinstance(m, DeQuantStub): |
|
|
|
m.add_module('activation_post_process', get_activation_post_process( |
|
m.qconfig, device, special_act_post_process)) |
|
|
|
|
|
_register_activation_post_process_hook( |
|
m, pre_hook=_activation_is_memoryless(m.qconfig)) |
|
|
|
for name, child in module.named_children(): |
|
|
|
if type_before_parametrizations(child) in [nn.Dropout]: |
|
continue |
|
elif issubclass(type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)): |
|
if needs_observation(child): |
|
assert hasattr(child, "activation_post_process"), ( |
|
f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" |
|
) |
|
child.activation_post_process = get_activation_post_process( |
|
child.qconfig, device) |
|
elif isinstance(child, _FusedModule): |
|
|
|
if needs_observation(child): |
|
insert_activation_post_process(child) |
|
elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list: |
|
if needs_observation(child): |
|
insert_activation_post_process(child) |
|
elif _has_special_act_post_process(child): |
|
special_act_post_process = _get_special_act_post_process(child) |
|
insert_activation_post_process(child, special_act_post_process) |
|
elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping: |
|
observed_child = custom_module_class_mapping[type_before_parametrizations( |
|
child)].from_float(child) |
|
setattr(module, name, observed_child) |
|
|
|
|
|
if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set(): |
|
insert_activation_post_process(observed_child) |
|
else: |
|
_add_observer_(child, qconfig_propagation_list, |
|
non_leaf_module_list, device, custom_module_class_mapping) |
|
|
|
|
|
|
|
if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \ |
|
and type_before_parametrizations(module) in qconfig_propagation_list: |
|
insert_activation_post_process(module) |
|
|
|
|
|
def _get_unique_devices_(module): |
|
return {p.device for p in module.parameters()} | \ |
|
{p.device for p in module.buffers()} |
|
|
|
|
|
def add_quant_dequant(module): |
|
r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig |
|
Note that this function will modify the children of module inplace and it |
|
can return a new module which wraps the input module as well. |
|
|
|
Args: |
|
module: input module with qconfig attributes for all the leaf modules |
|
that we want to quantize |
|
|
|
Return: |
|
Either the inplace modified module with submodules wrapped in |
|
`QuantWrapper` based on qconfig or a new `QuantWrapper` module which |
|
wraps the input module, the latter case only happens when the input |
|
module is a leaf module and we want to quantize it. |
|
""" |
|
if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig: |
|
return QuantWrapper(module) |
|
|
|
for name, child in module.named_children(): |
|
module._modules[name] = add_quant_dequant(child) |
|
return module |
|
|
|
|
|
def prepare(model, inplace=False, allow_list=None, |
|
observer_non_leaf_module_list=None, |
|
prepare_custom_config_dict=None): |
|
r"""Prepares a copy of the model for quantization calibration or quantization-aware training. |
|
|
|
Quantization configuration should be assigned preemptively |
|
to individual submodules in `.qconfig` attribute. |
|
|
|
The model will be attached with observer or fake quant modules, and qconfig |
|
will be propagated. |
|
|
|
Args: |
|
`model`: input model to be modified in-place |
|
`inplace`: carry out model transformations in-place, the original module is mutated |
|
`allow_list`: list of quantizable modules |
|
`observer_non_leaf_module_list`: list of non-leaf modules we want to add observer |
|
`prepare_custom_config_dict`: customization configuration dictionary for prepare function |
|
|
|
.. code-block:: python |
|
|
|
# Example of prepare_custom_config_dict: |
|
prepare_custom_config_dict = { |
|
# user will manually define the corresponding observed |
|
# module class which has a from_float class method that converts |
|
# float custom module to observed custom module |
|
"float_to_observed_custom_module_class": { |
|
CustomModule: ObservedCustomModule |
|
} |
|
} |
|
|
|
""" |
|
torch._C._log_api_usage_once("quantization_api.quantize.prepare") |
|
if prepare_custom_config_dict is None: |
|
prepare_custom_config_dict = get_default_custom_config_dict() |
|
custom_module_class_mapping = prepare_custom_config_dict.get( |
|
"float_to_observed_custom_module_class", {}) |
|
|
|
if not inplace: |
|
model = copy.deepcopy(model) |
|
|
|
|
|
qconfig_propagation_list = allow_list |
|
if allow_list is None: |
|
qconfig_propagation_list = get_default_qconfig_propagation_list() |
|
propagate_qconfig_(model, qconfig_dict=None) |
|
|
|
|
|
if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()): |
|
warnings.warn("None of the submodule got qconfig applied. Make sure you " |
|
"passed correct configuration through `qconfig_dict` or " |
|
"by assigning the `.qconfig` attribute directly on submodules") |
|
|
|
_add_observer_( |
|
model, qconfig_propagation_list, observer_non_leaf_module_list, |
|
custom_module_class_mapping=custom_module_class_mapping) |
|
return model |
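# Typical eager-mode post-training-quantization flow built from the functions in
# this file (a sketch of the upstream torch.ao.quantization API they adapt;
# `MyModel` and `calibration_loader` are placeholders). Note that `convert` here
# forwards `split`/`ckpt` to `from_float`, so in this repo it is meant to target
# the custom QuantizedLinear/QuantizedConv2d modules defined below:
#
#     model = MyModel().eval()
#     model.qconfig = torch.ao.quantization.default_qconfig
#     prepare(model, inplace=True)        # attach observers
#     for batch in calibration_loader:    # calibrate observers
#         model(batch)
#     convert(model, inplace=True)        # swap in quantized modules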
|
|
|
|
|
def _remove_activation_post_process(module): |
|
|
|
|
|
if hasattr(module, 'activation_post_process') and \ |
|
_is_activation_post_process(module.activation_post_process): |
|
delattr(module, 'activation_post_process') |
|
|
|
|
|
def remove_hooks(pre_hook=False): |
|
hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks |
|
observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook |
|
handle_ids_to_remove = set() |
|
for handle_id, hook_fn in hook_map.items(): |
|
if hook_fn is observer_hook: |
|
handle_ids_to_remove.add(handle_id) |
|
for handle_id in handle_ids_to_remove: |
|
hook_map.pop(handle_id) |
|
|
|
remove_hooks(pre_hook=True) |
|
remove_hooks(pre_hook=False) |
|
|
|
|
|
|
|
|
|
def _remove_qconfig(module): |
|
r"""Clean up the qconfig left in the module so that new qconfig can be |
|
propagated. |
|
|
|
Args: |
|
module: module to be cleaned up |
|
""" |
|
for child in module.children(): |
|
_remove_qconfig(child) |
|
|
|
if hasattr(module, "qconfig"): |
|
del module.qconfig |
|
|
|
_remove_activation_post_process(module) |
|
|
|
|
|
def quantize(model, run_fn, run_args, mapping=None, inplace=False): |
|
r"""Quantize the input float model with post training static quantization. |
|
|
|
First it will prepare the model for calibration, then it calls |
|
`run_fn` which will run the calibration step, after that we will |
|
convert the model to a quantized model. |
|
|
|
Args: |
|
model: input float model |
|
run_fn: a calibration function for calibrating the prepared model |
|
run_args: positional arguments for `run_fn` |
|
inplace: carry out model transformations in-place, the original module is mutated |
|
mapping: correspondence between original module types and quantized counterparts |
|
|
|
Return: |
|
Quantized model. |
|
""" |
|
torch._C._log_api_usage_once("quantization_api.quantize.quantize") |
|
if mapping is None: |
|
mapping = get_default_static_quant_module_mappings() |
|
if not inplace: |
|
model = copy.deepcopy(model) |
|
model.eval() |
|
prepare(model, inplace=True) |
|
run_fn(model, *run_args) |
|
convert(model, mapping, inplace=True) |
|
return model |
|
|
|
|
|
def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, |
|
mapping=None, inplace=False): |
|
r"""Converts a float model to dynamic (i.e. weights-only) quantized model. |
|
|
|
    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

    For the simplest usage, provide the `dtype` argument, which can be float16 or qint8. Weight-only quantization
    is by default performed for layers with large weight sizes, i.e. Linear and RNN variants.

    Fine-grained control is possible with `qconfig` and `mapping`, which act similarly to `quantize()`.
|
If `qconfig` is provided, the `dtype` argument is ignored. |
|
|
|
Args: |
|
model: input model |
|
qconfig_spec: Either: |
|
|
|
- A dictionary that maps from name or type of submodule to quantization |
|
configuration, qconfig applies to all submodules of a given |
|
module unless qconfig for the submodules are specified (when the |
|
submodule already has qconfig attribute). Entries in the dictionary |
|
need to be QConfig instances. |
|
|
|
- A set of types and/or submodule names to apply dynamic quantization to, |
|
in which case the `dtype` argument is used to specify the bit-width |
|
|
|
inplace: carry out model transformations in-place, the original module is mutated |
|
mapping: maps type of a submodule to a type of corresponding dynamically quantized version |
|
with which the submodule needs to be replaced |
|
|
|
""" |
|
torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic") |
|
if qconfig_spec is None: |
|
if dtype == torch.qint8: |
|
qconfig_spec = { |
|
nn.Linear: default_dynamic_qconfig, |
|
nn.LSTM: default_dynamic_qconfig, |
|
nn.GRU: default_dynamic_qconfig, |
|
nn.LSTMCell: default_dynamic_qconfig, |
|
nn.RNNCell: default_dynamic_qconfig, |
|
nn.GRUCell: default_dynamic_qconfig, |
|
} |
|
elif dtype == torch.float16: |
|
qconfig_spec = { |
|
nn.Linear: float16_dynamic_qconfig, |
|
nn.LSTM: float16_dynamic_qconfig, |
|
nn.GRU: float16_dynamic_qconfig, |
|
nn.LSTMCell: float16_dynamic_qconfig, |
|
nn.RNNCell: float16_dynamic_qconfig, |
|
nn.GRUCell: float16_dynamic_qconfig, |
|
} |
|
elif dtype == torch.quint8: |
|
qconfig_spec = { |
|
nn.EmbeddingBag: float_qparams_weight_only_qconfig, |
|
nn.Embedding: float_qparams_weight_only_qconfig, |
|
} |
|
elif dtype == torch.quint4x2: |
|
qconfig_spec = { |
|
nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit, |
|
} |
|
else: |
|
raise ValueError( |
|
f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please") |
|
elif isinstance(qconfig_spec, set): |
|
if dtype is torch.qint8: |
|
default_qconfig = default_dynamic_qconfig |
|
elif dtype is torch.float16: |
|
default_qconfig = float16_dynamic_qconfig |
|
elif dtype is torch.quint8: |
|
default_qconfig = float_qparams_weight_only_qconfig |
|
elif dtype is torch.quint4x2: |
|
default_qconfig = float_qparams_weight_only_qconfig_4bit |
|
else: |
|
raise RuntimeError( |
|
'Unknown dtype specified for quantize_dynamic: ', str(dtype)) |
|
qconfig_spec = dict( |
|
zip(qconfig_spec, itertools.repeat(default_qconfig))) |
|
|
|
if mapping is None: |
|
mapping = get_default_dynamic_quant_module_mappings() |
|
|
|
if not inplace: |
|
model = copy.deepcopy(model) |
|
model.eval() |
|
propagate_qconfig_(model, qconfig_spec) |
|
convert(model, mapping, inplace=True) |
|
return model |
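# Upstream-style usage sketch: dynamically quantize all Linear layers of a float
# model to int8 weights (`model` is a placeholder):
#
#     qmodel = quantize_dynamic(model, qconfig_spec={nn.Linear}, dtype=torch.qint8)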
|
|
|
|
|
def prepare_qat(model, mapping=None, inplace=False): |
|
r""" |
|
Prepares a copy of the model for quantization calibration or |
|
    quantization-aware training and converts it to a quantized version.
|
|
|
Quantization configuration should be assigned preemptively |
|
to individual submodules in `.qconfig` attribute. |
|
|
|
Args: |
|
model: input model to be modified in-place |
|
mapping: dictionary that maps float modules to quantized modules to be |
|
replaced. |
|
inplace: carry out model transformations in-place, the original module |
|
is mutated |
|
""" |
|
torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") |
|
assert model.training, "prepare_qat only works on models in training mode" |
|
if mapping is None: |
|
mapping = get_default_qat_module_mappings() |
|
|
|
if not inplace: |
|
model = copy.deepcopy(model) |
|
|
|
propagate_qconfig_(model, qconfig_dict=None) |
|
convert(model, mapping=mapping, inplace=True, remove_qconfig=False) |
|
prepare(model, observer_non_leaf_module_list=set( |
|
mapping.values()), inplace=True) |
|
return model |
|
|
|
|
|
def quantize_qat(model, run_fn, run_args, inplace=False): |
|
r"""Do quantization aware training and output a quantized model |
|
|
|
Args: |
|
model: input model |
|
run_fn: a function for evaluating the prepared model, can be a |
|
function that simply runs the prepared model or a training |
|
loop |
|
run_args: positional arguments for `run_fn` |
|
|
|
Return: |
|
Quantized model. |
|
""" |
|
torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") |
|
if not inplace: |
|
model = copy.deepcopy(model) |
|
model.train() |
|
prepare_qat(model, inplace=True) |
|
run_fn(model, *run_args) |
|
convert(model, inplace=True) |
|
return model |
|
|
|
|
|
def convert( |
|
module, mapping=None, inplace=False, remove_qconfig=True, |
|
is_reference=False, convert_custom_config_dict=None, ckpt=None): |
|
r"""Converts submodules in input module to a different module according to `mapping` |
|
    by calling the `from_float` method on the target module class, and removes qconfig at the
    end if `remove_qconfig` is set to True.
|
|
|
Args: |
|
`module`: prepared and calibrated module |
|
`mapping`: a dictionary that maps from source module type to target |
|
module type, can be overwritten to allow swapping user defined |
|
Modules |
|
`inplace`: carry out model transformations in-place, the original module |
|
is mutated |
|
`convert_custom_config_dict`: custom configuration dictionary for convert function |
|
|
|
.. code-block:: python |
|
|
|
# Example of convert_custom_config_dict: |
|
convert_custom_config_dict = { |
|
# user will manually define the corresponding quantized |
|
# module class which has a from_observed class method that converts |
|
# observed custom module to quantized custom module |
|
"observed_to_quantized_custom_module_class": { |
|
ObservedCustomModule: QuantizedCustomModule |
|
} |
|
} |
|
|
|
""" |
|
torch._C._log_api_usage_once("quantization_api.quantize.convert") |
|
if not inplace: |
|
module = copy.deepcopy(module) |
|
_convert( |
|
module, mapping, inplace=True, is_reference=is_reference, |
|
convert_custom_config_dict=convert_custom_config_dict, ckpt=ckpt) |
|
if remove_qconfig: |
|
_remove_qconfig(module) |
|
return module |
|
|
|
|
|
def _convert( |
|
module, mapping=None, inplace=False, |
|
is_reference=False, convert_custom_config_dict=None, ckpt=None): |
|
r"""Converts submodules in input module to a different module according to `mapping` |
|
by calling `from_float` method on the target module class |
|
|
|
Args: |
|
module: input module |
|
mapping: a dictionary that maps from source module type to target |
|
module type, can be overwritten to allow swapping user defined |
|
Modules |
|
inplace: carry out model transformations in-place, the original module |
|
is mutated |
|
is_reference: a flag to enable quantized reference module |
|
|
|
""" |
|
if mapping is None: |
|
mapping = get_default_static_quant_reference_module_mappings() if is_reference \ |
|
else get_default_static_quant_module_mappings() |
|
if convert_custom_config_dict is None: |
|
convert_custom_config_dict = get_default_custom_config_dict() |
|
custom_module_class_mapping = convert_custom_config_dict.get( |
|
"observed_to_quantized_custom_module_class", {}) |
|
|
|
if not inplace: |
|
module = copy.deepcopy(module) |
|
reassign = {} |
|
for name, mod in module.named_children(): |
|
|
|
|
|
if not isinstance(mod, _FusedModule) and \ |
|
type_before_parametrizations(mod) not in custom_module_class_mapping: |
|
_convert(mod, mapping, True, |
|
is_reference, convert_custom_config_dict, ckpt=ckpt) |
|
reassign[name] = swap_module( |
|
mod, mapping, custom_module_class_mapping, ckpt=ckpt) |
|
|
|
for key, value in reassign.items(): |
|
module._modules[key] = value |
|
|
|
return module |
|
|
|
|
|
def swap_module(mod, mapping, custom_module_class_mapping, ckpt=None): |
|
    r"""Swaps the module if it has a quantized counterpart and it has an
|
`observer` attached. |
|
|
|
Args: |
|
mod: input module |
|
mapping: a dictionary that maps from nn module to nnq module |
|
|
|
Return: |
|
The corresponding quantized module of `mod` |
|
""" |
|
    global _NUM
    new_mod = mod
|
if hasattr(mod, 'qconfig') and mod.qconfig is not None: |
|
swapped = False |
|
if type_before_parametrizations(mod) in custom_module_class_mapping: |
|
new_mod = custom_module_class_mapping[type_before_parametrizations( |
|
mod)].from_observed(mod) |
|
swapped = True |
|
elif type_before_parametrizations(mod) in mapping: |
|
qmod = mapping[type_before_parametrizations(mod)] |
|
if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE: |
|
assert mod.qconfig is not None |
|
weight_post_process = mod.qconfig.weight() |
|
weight_post_process(mod.weight) |
|
weight_qparams = get_qparam_dict(weight_post_process) |
|
if 'up_blocks' in mod.module_name and 'conv_shortcut' in mod.module_name: |
|
|
|
_split = _SPLIT[_NUM] |
|
_NUM = _NUM + 1 |
|
|
|
else: |
|
_split = 0 |
|
new_mod = qmod.from_float(mod, weight_qparams, split=_split) |
|
else: |
|
if 'up_blocks' in mod.module_name and 'conv_shortcut' in mod.module_name: |
|
|
|
_split = _SPLIT[_NUM] |
|
_NUM = _NUM + 1 |
|
|
|
print(f"split at {_split}") |
|
else: |
|
_split = 0 |
|
new_mod = qmod.from_float(mod, split=_split, ckpt=ckpt) |
|
swapped = True |
|
|
|
if swapped: |
|
|
|
for pre_hook_fn in mod._forward_pre_hooks.values(): |
|
new_mod.register_forward_pre_hook(pre_hook_fn) |
|
|
|
|
|
for hook_fn in mod._forward_hooks.values(): |
|
if hook_fn is not _observer_forward_hook: |
|
new_mod.register_forward_hook(hook_fn) |
|
|
|
|
|
devices = _get_unique_devices_(mod) |
|
assert len(devices) <= 1, ( |
|
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" |
|
) |
|
device = next(iter(devices)) if len(devices) > 0 else None |
|
if device: |
|
new_mod.to(device) |
|
return new_mod |
|
|
|
|
|
def _get_observer_dict(mod, target_dict, prefix=""): |
|
r"""Traverse the modules and save all observers into dict. |
|
This is mainly used for quantization accuracy debug |
|
Args: |
|
mod: the top module we want to save all observers |
|
prefix: the prefix for the current module |
|
target_dict: the dictionary used to save all the observers |
|
""" |
|
def get_prefix(prefix): |
|
return prefix if prefix == "" else prefix + '.' |
|
|
|
if hasattr(mod, 'activation_post_process'): |
|
target_dict[get_prefix( |
|
prefix) + 'activation_post_process'] = mod.activation_post_process |
|
for name, child in mod.named_children(): |
|
module_prefix = get_prefix(prefix) + name if prefix else name |
|
_get_observer_dict(child, target_dict, module_prefix) |
|
|
|
|
|
def filter_mod_name_prefix(mod_name): |
|
if 'model.' in mod_name: |
|
pos = mod_name.index('model.') |
|
mod_name = mod_name[pos + 6:] |
|
return mod_name |
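# Example: filter_mod_name_prefix("model.down_blocks.0.resnets.0.conv1")
# returns "down_blocks.0.resnets.0.conv1".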
|
|
|
|
|
def register_qconfig_from_input_files( |
|
unet, |
|
|
|
w_bit=8, |
|
a_bit=None, |
|
bos=True, |
|
bos_dict=None |
|
): |
|
import yaml |
|
|
|
bw_to_dtype = { |
|
8: torch.qint8, |
|
4: torch.quint4x2, |
|
2: torch.quint4x2, |
|
} |
|
|
|
|
|
|
|
    # `w8_uniform_config` is assumed to be defined at module scope (e.g. loaded
    # from the input YAML files this function is named after).
    if w_bit == 8:
        mod_name_to_weight_width = w8_uniform_config
    else:
        raise RuntimeError("Only int8 weight quantization is supported")
|
|
|
mod_name_to_weight_width_copy = {} |
|
for mod_name, bit_width in mod_name_to_weight_width.items(): |
|
new_name = filter_mod_name_prefix(mod_name) |
|
mod_name_to_weight_width_copy[new_name] = bit_width |
|
mod_name_to_weight_width = mod_name_to_weight_width_copy |
|
|
|
|
|
    # Alias, not a copy: matched entries are deleted below, so any leftovers are
    # yaml keys that did not map to a UNet module.
    mod_name_to_weight_width_copy = mod_name_to_weight_width
|
for name, mod in unet.named_modules(): |
|
if name in mod_name_to_weight_width: |
|
assert not hasattr(mod, 'qconfig') |
|
|
|
w_bitwidth = mod_name_to_weight_width[name] |
|
w_dtype = bw_to_dtype[w_bitwidth] |
|
act_preprocess = PlaceholderObserver.with_args( |
|
dtype=torch.float16) |
|
weight_process = PlaceholderObserver.with_args(dtype=w_dtype) |
|
mod.qconfig = \ |
|
QConfig(activation=act_preprocess, weight=weight_process) |
|
|
|
|
|
mod.module_name = name |
|
|
|
mod.w_bit = mod_name_to_weight_width[name] |
|
if 'attn2' in name: |
|
if 'to_k' in name or 'to_v' in name: |
|
mod.bos = bos |
|
mod.bos_pre_computed = bos_dict[name] |
|
|
|
del mod_name_to_weight_width_copy[name] |
|
|
|
if len(mod_name_to_weight_width_copy): |
|
for name in mod_name_to_weight_width_copy.keys(): |
|
print(f"{name} not found in UNet!") |
|
raise RuntimeError("Not all keys in weight yaml map to a module in " |
|
"UNet.") |
|
|
|
|
|
if a_bit is None: |
|
return |
|
|
|
|
|
    # `a8_mixed_precision_config` is likewise assumed to be defined at module scope.
    if a_bit == 8:
        mod_name_to_act_width = a8_mixed_precision_config
    else:
        raise RuntimeError("Only int8 activation quantization is supported")
|
|
|
mod_name_to_act_width_copy = {} |
|
for mod_name, bit_width in mod_name_to_act_width.items(): |
|
new_name = filter_mod_name_prefix(mod_name) |
|
mod_name_to_act_width_copy[new_name] = bit_width |
|
mod_name_to_act_width = mod_name_to_act_width_copy |
|
|
|
|
|
    # Alias, not a copy (see the weight loop above).
    mod_name_to_act_width_copy = mod_name_to_act_width
|
for name, mod in unet.named_modules(): |
|
if name in mod_name_to_act_width: |
|
a_bitwidth = mod_name_to_act_width[name] |
|
a_dtype = bw_to_dtype[a_bitwidth] |
|
act_preprocess = PlaceholderObserver.with_args(dtype=a_dtype) |
|
if hasattr(mod, 'qconfig') and mod.qconfig: |
|
assert isinstance(mod.qconfig, QConfig) |
|
mod.qconfig = QConfig(weight=mod.qconfig.weight, |
|
activation=act_preprocess) |
|
else: |
|
weight_process = PlaceholderObserver.with_args( |
|
dtype=torch.float16) |
|
mod.qconfig = QConfig(activation=act_preprocess, |
|
weight=weight_process) |
|
|
|
|
|
|
|
mod.a_bit = mod_name_to_act_width[name] |
|
|
|
del mod_name_to_act_width_copy[name] |
|
|
|
if len(mod_name_to_act_width_copy): |
|
for name in mod_name_to_act_width_copy.keys(): |
|
print(f"{name} not found in UNet!") |
|
raise RuntimeError("Not all keys in act yaml map to a module in " |
|
"UNet.") |
|
|
|
|
|
def convert_to_quantized(unet, ckpt): |
|
|
|
convert(unet, |
|
mapping={nn.Linear: QuantizedLinear, |
|
nn.Conv2d: QuantizedConv2d, |
|
|
|
|
|
}, |
|
inplace=True, |
|
ckpt=ckpt) |
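# End-to-end sketch of this module's intended flow (assumes `unet` is a diffusers
# UNet2DConditionModel and `ckpt` holds quantization parameters with the layout
# described near get_quant_para below):
#
#     register_qconfig_from_input_files(unet, w_bit=8, a_bit=8,
#                                       bos=True, bos_dict=bos_dict)
#     convert_to_quantized(unet, ckpt)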
|
|
def quantize_per_tensor_uint4( |
|
input: torch.Tensor, scale, zero_point, |
|
): |
|
|
|
|
|
scale = scale.view(-1, *([1] * (len(input.shape) - 1))) |
|
zero_point = zero_point.view(-1, *([1] * (len(input.shape) - 1))) |
|
|
|
|
|
scale_inv = 1.0 / scale |
|
int_repr = torch.clamp(torch.round(input * scale_inv) + |
|
zero_point, 0, 15).to(torch.uint8) |
|
if len(input.shape) >= 4: |
|
assert input.shape[1] % 2 == 0 |
|
return (int_repr[:, ::2, ...] << 4 | int_repr[:, 1::2, ...]) |
|
assert input.shape[-1] % 2 == 0 |
|
return (int_repr[..., ::2] << 4 | int_repr[..., 1::2]) |
|
|
|
|
|
def unpack_uint4(input): |
|
shape = input.shape |
|
if len(shape) >= 4: |
|
packed_dim = 2 |
|
new_shape = (input.shape[0], input.shape[1]*2, *input.shape[2:]) |
|
else: |
|
packed_dim = -1 |
|
new_shape = (*input.shape[:-1], input.shape[-1]*2) |
|
first_elements = (input >> 4).to(torch.uint8) |
|
second_elements = (input & 0b1111).to(torch.uint8) |
|
return torch.stack([first_elements, second_elements], dim=packed_dim).view(new_shape) |
|
|
|
|
|
def dequantize_per_tensor_uint4( |
|
input, scale, zero_point, |
|
): |
|
|
|
scale = scale.view(-1, *([1] * (len(input.shape) - 1))) |
|
zero_point = zero_point.view(-1, *([1] * (len(input.shape) - 1))) |
|
|
|
input = unpack_uint4(input) |
|
return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale |
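# Round-trip sketch for the uint4 helpers above: two 4-bit values are packed per
# uint8 along the channel dim (4D tensors) or the last dim (otherwise), so the
# packed tensor is half-sized along that dim and unpacking restores it:
#
#     x = torch.randn(1, 4, 8, 8)
#     scale, zp = torch.tensor([0.1]), torch.tensor([8.0])
#     packed = quantize_per_tensor_uint4(x, scale, zp)        # (1, 2, 8, 8) uint8
#     x_hat = dequantize_per_tensor_uint4(packed, scale, zp)  # (1, 4, 8, 8) float32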
|
|
|
|
|
dtype_to_bw = { |
|
torch.quint8: 8, |
|
torch.quint4x2: 4, |
|
torch.quint2x4: 2, |
|
torch.float16: 16, |
|
} |
|
|
|
|
|
class QParam(namedtuple("QParam", ["qscheme", "dtype", "scales", "zero_points", "axis"], defaults=[torch.per_tensor_affine, torch.quint8, 1.0, 0.0, 0])): |
|
@property |
|
def zp_float(self): |
|
return self.scales * self.zero_points |
|
|
|
|
|
|
def create_qparams_from_dtype( |
|
dtype, |
|
device, |
|
is_channel_wise=False, |
|
num_kernels=None, |
|
ckpt=None, |
|
module_name=None, |
|
bit_width=0, |
|
quant_type=None, |
|
split=0, |
|
): |
|
|
|
if dtype == torch.float16: |
|
        # Callers unpack two values (qparam, qparam_0), so return a pair.
        return None, None
|
elif dtype in [torch.qint8, torch.quint8, torch.quint4x2]: |
|
if quant_type == 'weight': |
|
scales, zero_points, scales_0, zero_points_0 = get_quant_para(ckpt, |
|
bit_width, |
|
module_name, |
|
quant_type='weight', |
|
split=split, |
|
device=device) |
|
elif quant_type == 'act': |
|
scales, zero_points, scales_0, zero_points_0 = get_quant_para(ckpt, |
|
bit_width, |
|
module_name, |
|
quant_type='act', |
|
split=split, |
|
device=device) |
|
else: |
|
raise ValueError(f"Unsupported quantize dtype {dtype}") |
|
|
|
if is_channel_wise: |
|
assert num_kernels is not None |
|
qparam = QParam(qscheme=torch.per_channel_affine, |
|
scales=scales, zero_points=zero_points, |
|
dtype=dtype, axis=0) |
|
if split > 0: |
|
qparam_0 = QParam(qscheme=torch.per_channel_affine, |
|
scales=scales_0, zero_points=zero_points_0, |
|
dtype=dtype, axis=0) |
|
else: |
|
qparam_0 = None |
|
|
|
else: |
|
qparam = QParam(qscheme=torch.per_tensor_affine, |
|
scales=scales, zero_points=zero_points, |
|
dtype=dtype) |
|
|
|
if split > 0: |
|
qparam_0 = QParam(qscheme=torch.per_tensor_affine, |
|
scales=scales_0, zero_points=zero_points_0, |
|
dtype=dtype) |
|
else: |
|
qparam_0 = None |
|
|
|
return qparam, qparam_0 |
|
|
|
|
|
def quantize_from_qparams(x: torch.Tensor, qparams: QParam): |
|
if qparams.dtype == torch.quint4x2: |
|
|
|
|
|
|
|
return quantize_per_tensor_uint4(x, qparams.scales.to(x.device), qparams.zero_points.to(x.device)) |
|
|
|
if qparams.qscheme in [torch.per_tensor_affine]: |
|
scales = qparams.scales |
|
scales = scales.clone().detach().to(x.device) \ |
|
if isinstance(scales, torch.Tensor) \ |
|
else torch.tensor(scales, dtype=torch.float16, device=x.device) |
|
zps = qparams.zero_points |
|
zps = zps.clone().detach().to(x.device) \ |
|
if isinstance(zps, torch.Tensor) \ |
|
else torch.tensor(zps, dtype=torch.float16, device=x.device) |
|
|
|
|
|
x = x.to(torch.float32) |
|
x_quant = torch.quantize_per_tensor(x, scales, zps, qparams.dtype) |
|
elif qparams.qscheme in [torch.per_channel_affine]: |
|
scales = qparams.scales |
|
assert isinstance(scales, torch.Tensor) |
|
scales = scales.clone().detach().to(x.device) |
|
zps = qparams.zero_points |
|
assert isinstance(zps, torch.Tensor) |
|
zps = zps.clone().detach().to(x.device) |
|
assert qparams.axis < len(x.shape) |
|
|
|
x = x.to(torch.float32) |
|
|
|
|
|
|
|
|
|
|
|
x_quant = torch.quantize_per_channel(x, scales, zps, axis=qparams.axis, |
|
dtype=qparams.dtype) |
|
else: |
|
raise ValueError(f"Unknown qscheme {qparams.qscheme}") |
|
return x_quant |
|
|
|
|
|
def dequantize_to_float16_linear(x: torch.Tensor, qparams: QParam): |
|
if x.dtype == torch.float16: |
|
return x |
|
    if x.dtype in [torch.quint8, torch.qint8]:
        # Note: unlike dequantize_to_float16 below, this path returns float32.
        return x.dequantize().to(torch.float32)
|
elif x.dtype in [torch.int8]: |
|
scale = (qparams.scales.view(-1, * |
|
([1] * (len(x.shape) - 1)))).cuda().float() |
|
zero_points = (qparams.zero_points.view(-1, * |
|
([1] * (len(x.shape) - 1)))).cuda().float() |
|
|
|
x = scale*(x - zero_points) |
|
return x |
|
|
|
assert x.dtype == torch.uint8 |
|
return dequantize_per_tensor_uint4(x, qparams.scales.to(x.device), qparams.zero_points.to(x.device)).to(torch.float16) |
|
|
|
|
|
def dequantize_to_float16(x: torch.Tensor, qparams: QParam): |
|
if x.dtype == torch.float16: |
|
return x |
|
if x.dtype in [torch.quint8, torch.qint8]: |
|
return x.dequantize().to(torch.float16) |
|
elif x.dtype in [torch.int8]: |
|
scale = (qparams.scales.view(-1, *([1] * (len(x.shape) - 1)))).cuda() |
|
zero_points = (qparams.zero_points.view(-1, * |
|
([1] * (len(x.shape) - 1)))).cuda() |
|
|
|
x = scale*(x - zero_points) |
|
return x |
|
|
|
assert x.dtype == torch.uint8 |
|
return dequantize_per_tensor_uint4(x, qparams.scales.to(x.device), qparams.zero_points.to(x.device)).to(torch.float16) |
|
|
|
|
|
def linear_on_quantized_data( |
|
w_tensor: torch.Tensor = None, |
|
w_tensor_org: torch.Tensor = None, |
|
w_qparams: QParam = None, |
|
key_first_token: torch.Tensor = None, |
|
a_tensor: torch.Tensor = None, |
|
a_qparams: QParam = None, |
|
bias: Optional[torch.Tensor] = None, |
|
bos: bool = False, |
|
module_name=None, |
|
bos_pre_computed=None, |
|
|
|
|
|
) -> torch.Tensor: |
|
if not bos: |
|
|
|
if a_qparams is not None: |
|
out = gemm_cutlass(w_qparams, a_qparams, w_tensor, a_tensor, bias) |
|
return out |
|
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
a_tensor = dequantize_to_float16( |
|
a_tensor, a_qparams) if a_qparams is not None else a_tensor |
|
w_tensor = dequantize_to_float16(w_tensor, w_qparams) |
|
return F.linear(a_tensor, w_tensor, bias) |
|
|
|
else: |
|
print("apply bos!") |
|
|
|
|
|
|
|
|
|
|
|
out_0 = bos_pre_computed.cuda() |
|
|
|
|
|
|
|
|
|
out_1 = gemm_cutlass(w_qparams, a_qparams, w_tensor, a_tensor, bias) |
|
out_0 = out_0.expand(out_1.shape[0], -1, -1) |
|
|
|
return torch.cat([out_0, out_1], dim=1) |
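# Note on the `bos` path above: for cross-attention K/V projections, the first
# (BOS) token's output is precomputed in fp16 (`bos_pre_computed`) while the
# remaining tokens go through the int8 GEMM; the two pieces are concatenated
# along the sequence dimension. QuantizedLinear.forward slices off x[:, 1:, :]
# accordingly before quantizing.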
|
|
|
|
|
def conv2d_on_quantized_data( |
|
w_tensor: torch.Tensor = None, |
|
w_tensor_0: torch.Tensor = None, |
|
w_qparams: QParam = None, |
|
w_qparams_0: QParam = None, |
|
a_tensor: torch.Tensor = None, |
|
a_tensor_0: torch.Tensor = None, |
|
a_qparams: QParam = None, |
|
a_qparams_0: QParam = None, |
|
bias: Optional[torch.Tensor] = None, |
|
stride=1, |
|
padding=0, |
|
dilation=1, |
|
groups=1, |
|
split=0 |
|
) -> torch.Tensor: |
|
|
|
if split == 0: |
|
if a_qparams is not None: |
|
out = conv_cutlass(w_qparams, a_qparams, w_tensor, |
|
a_tensor, bias, stride, padding, dilation, groups) |
|
return out |
|
|
|
else: |
|
a_tensor = dequantize_to_float16( |
|
a_tensor, a_qparams) if a_qparams is not None else a_tensor |
|
w_tensor = dequantize_to_float16(w_tensor, w_qparams) |
|
return F.conv2d(a_tensor, w_tensor, bias, stride, padding, dilation, groups) |
|
|
|
elif split > 0: |
|
if a_qparams is not None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out = conv_cutlass(w_qparams, a_qparams, w_tensor, |
|
a_tensor, None, stride, padding, dilation, groups) |
|
out_0 = conv_cutlass(w_qparams_0, a_qparams_0, w_tensor_0, |
|
a_tensor_0, None, stride, padding, dilation, groups) |
|
|
|
shape = bias.size() |
|
bias = bias.reshape(1, shape[0], 1, 1) |
|
out = out + out_0 + bias |
|
|
|
else: |
|
weight = dequantize_to_float16(w_tensor, w_qparams) |
|
weight_0 = dequantize_to_float16(w_tensor_0, w_qparams_0) |
|
|
w_tensor = torch.cat([weight, weight_0], dim=1) |
|
out = F.conv2d(a_tensor, w_tensor, bias, stride, |
|
padding, dilation, groups) |
|
|
|
|
|
return out |
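# Note on `split`: convolution is linear in its input channels, so a conv over C
# channels equals the sum of two convs over the channel halves:
#
#     conv2d(x, W) == conv2d(x[:, :s], W[:, :s]) + conv2d(x[:, s:], W[:, s:])
#
# Splitting lets each half carry its own quantization parameters (w_qparams vs.
# w_qparams_0), which is what the `split > 0` branch implements, adding the bias
# once at the end.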
|
|
|
|
|
def gemm_cutlass(w_qparams, a_qparams, w_tensor, a_tensor, bias): |
|
s_w = w_qparams.scales.cuda().float() |
|
s_a = a_qparams.scales.cuda().float() |
|
z_a = a_qparams.zero_points.cuda().float() |
|
zps_a = a_qparams.zp_float.cuda().float() |
|
|
|
a_int = a_tensor |
|
|
|
w_int = w_tensor.int_repr() |
|
|
|
output_ref = qlinear( |
|
a_int, |
|
w_int, |
|
s_w, |
|
s_a, |
|
z_a, |
|
bias |
|
    )
|
print("run gemm on tensor core") |
|
|
|
|
|
return output_ref |
|
|
|
|
|
def conv_cutlass(w_qparams, a_qparams, w_tensor, a_tensor, bias, stride, padding, dilation, groups): |
|
print("run qconv2d!") |
|
s_w = w_qparams.scales.cuda().to(torch.float32) |
|
s_a = a_qparams.scales.cuda().to(torch.float32) |
|
z_a = a_qparams.zero_points.cuda().to(torch.float32) |
|
zps_a = a_qparams.zp_float.cuda().to(torch.float32) |
|
|
|
a_int = a_tensor |
|
w_int = w_tensor.int_repr() |
|
|
|
a_int = a_int.to(memory_format=torch.channels_last) |
|
w_int = w_int.to(memory_format=torch.channels_last) |
|
|
|
    if len(set(padding)) == 1:
        padding = padding[0]
    else:
        raise RuntimeError(f"asymmetric padding {padding} is not supported")
    if len(set(stride)) == 1:
        stride = stride[0]
    else:
        raise RuntimeError(f"anisotropic stride {stride} is not supported")
|
|
|
output = qconv2d( |
|
a_int, |
|
w_int, |
|
s_w, |
|
s_a, |
|
z_a, |
|
bias, |
|
stride, |
|
padding,) |
|
|
|
return output |
|
|
|
|
|
def get_quant_para(ckpt, n_bit, module_name, quant_type, split=0, device=None): |
|
|
|
if split == 0: |
|
bit_idx = int(math.log2(n_bit)-1) |
|
|
|
if quant_type == 'weight': |
|
module_name = module_name + '.weight_quantizer' |
|
assert module_name in ckpt.keys() |
|
scales = ckpt[module_name]['delta_list'][bit_idx] |
|
|
|
zero_point = ckpt[module_name]['zero_point_list'][bit_idx] |
|
|
|
|
|
elif quant_type == 'act': |
|
module_name = module_name + '.act_quantizer' |
|
assert module_name in ckpt.keys() |
|
scales = ckpt[module_name]['delta_list'][bit_idx] |
|
|
|
zero_point = ckpt[module_name]['zero_point_list'][bit_idx] - 128 |
|
|
|
return scales.to(device), zero_point.to(device), None, None |
|
|
|
elif split > 0: |
|
bit_idx = int(math.log2(n_bit)-1) |
|
|
|
if quant_type == 'weight': |
|
module_name = module_name + '.weight_quantizer' |
|
assert module_name in ckpt.keys() |
|
scales = ckpt[module_name]['delta_list'][bit_idx] |
|
zero_point = ckpt[module_name]['zero_point_list'][bit_idx] |
|
|
|
module_name = module_name + '_0' |
|
assert module_name in ckpt.keys() |
|
scales_0 = ckpt[module_name]['delta_list'][bit_idx] |
|
zero_point_0 = ckpt[module_name]['zero_point_list'][bit_idx] |
|
|
|
|
|
elif quant_type == 'act': |
|
module_name = module_name + '.act_quantizer' |
|
|
|
assert module_name in ckpt.keys() |
|
scales = ckpt[module_name]['delta_list'][bit_idx] |
|
zero_point = ckpt[module_name]['zero_point_list'][bit_idx]-128 |
|
|
|
module_name = module_name + '_0' |
|
assert module_name in ckpt.keys() |
|
scales_0 = ckpt[module_name]['delta_list'][bit_idx] |
|
zero_point_0 = ckpt[module_name]['zero_point_list'][bit_idx]-128 |
|
|
|
return scales.to(device), zero_point.to(device), scales_0.to(device), zero_point_0.to(device) |
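# Assumed checkpoint layout (inferred from the lookups above): `ckpt` maps
# "<module_name>.weight_quantizer" / "<module_name>.act_quantizer" (with an
# extra "_0" suffix for the second half of split layers) to dicts holding
# 'delta_list' (scales) and 'zero_point_list' (zero points), indexed by
# bit_idx = log2(n_bit) - 1, e.g. index 2 for 8 bits.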
|
|
class QuantizedLinear(nn.Module): |
|
def __init__(self, in_features: int, out_features: int, bias: bool = True, |
|
device=None, w_qparams=None, a_qparams=None, module_name=None) -> None: |
|
super().__init__() |
|
self.module_name = module_name |
|
|
|
|
|
self.in_features = in_features |
|
self.out_features = out_features |
|
self.device = device |
|
self.w_qparams = w_qparams |
|
self.a_qparams = a_qparams |
|
if self.w_qparams is not None: |
|
self.register_buffer("weight_scales", self.w_qparams.scales) |
|
self.register_buffer("weight_zero_points", |
|
self.w_qparams.zero_points) |
|
if self.a_qparams is not None: |
|
self.register_buffer("act_scales", self.a_qparams.scales) |
|
self.register_buffer("act_zero_points", self.a_qparams.zero_points) |
|
|
|
@classmethod |
|
def from_float(cls, float_mod, split=0, ckpt=None): |
|
assert hasattr(float_mod, 'qconfig') and isinstance(float_mod.qconfig, |
|
QConfig) |
|
weight_process = float_mod.qconfig.weight() |
|
w_dtype = weight_process.dtype |
|
num_kernels = float_mod.weight.shape[0] |
|
device = float_mod.weight.device |
|
|
|
w_qparams, w_qparams_0 = create_qparams_from_dtype(dtype=w_dtype, |
|
device=device, |
|
is_channel_wise=True, |
|
num_kernels=num_kernels, |
|
ckpt=ckpt, |
|
module_name=float_mod.module_name, |
|
quant_type='weight', |
|
bit_width=float_mod.w_bit, |
|
split=split) |
|
|
|
act_process = float_mod.qconfig.activation() |
|
act_dtype = act_process.dtype |
|
|
|
if hasattr(float_mod, 'a_bit'): |
|
a_qparams, a_qparams_0 = create_qparams_from_dtype(dtype=act_dtype, |
|
device=device, |
|
is_channel_wise=False, |
|
num_kernels=num_kernels, |
|
ckpt=ckpt, |
|
module_name=float_mod.module_name, |
|
quant_type='act', |
|
bit_width=float_mod.a_bit, |
|
split=split) |
|
else: |
|
a_qparams = None |
|
a_qparams_0 = None |
|
|
|
new_mod = cls(float_mod.in_features, |
|
float_mod.out_features, |
|
float_mod.bias is not None, |
|
device=float_mod.weight.device, |
|
w_qparams=w_qparams, |
|
a_qparams=a_qparams, |
|
module_name=float_mod.module_name, |
|
) |
|
|
|
weight = float_mod.weight.detach() |
|
|
|
if 'attn2' in float_mod.module_name: |
|
if 'to_k' in float_mod.module_name or 'to_v' in float_mod.module_name: |
|
new_mod.bos = float_mod.bos |
|
new_mod.bos_pre_computed = float_mod.bos_pre_computed |
|
|
|
|
|
|
|
if w_qparams is not None: |
|
weight = quantize_from_qparams(weight, w_qparams) |
|
new_mod.register_buffer("weight", weight) |
|
if float_mod.bias is not None: |
|
bias = float_mod.bias.detach() |
|
new_mod.register_buffer("bias", bias) |
|
else: |
|
new_mod.bias = None |
|
return new_mod |
|
|
|
def _get_name(self): |
|
w_width = 16 if self.w_qparams is None else \ |
|
dtype_to_bw[self.w_qparams.dtype] |
|
a_width = 16 if self.a_qparams is None else \ |
|
dtype_to_bw[self.a_qparams.dtype] |
|
return f"QuantizedLinear(W({w_width})A({a_width}))" |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
if not hasattr(self, 'bos'): |
|
if self.a_qparams is not None and x.dtype == torch.float16: |
|
|
|
x = quantize_per_tensor(x, self.a_qparams.scales.cuda().float( |
|
), self.a_qparams.zero_points.cuda().float()) if x.dtype is not torch.int8 else x |
|
|
|
return linear_on_quantized_data(w_tensor=self.weight, w_qparams=self.w_qparams, a_tensor=x, |
|
a_qparams=self.a_qparams, bias=self.bias) |
|
else: |
|
if self.a_qparams is not None and x.dtype == torch.float16 and self.bos: |
|
|
|
|
|
x_0 = quantize_per_tensor(x[:, 1:, :], self.a_qparams.scales.cuda().float( |
|
), self.a_qparams.zero_points.cuda().float()) if x.dtype is not torch.int8 else x[:, 1:, :] |
|
|
|
|
|
|
|
|
|
result = linear_on_quantized_data(w_tensor=self.weight, w_qparams=self.w_qparams, a_tensor=x_0, |
|
a_qparams=self.a_qparams, bias=self.bias, bos=True, module_name=self.module_name, bos_pre_computed=self.bos_pre_computed) |
|
|
|
|
|
return result |
|
|
|
else: |
|
if self.a_qparams is not None and x.dtype == torch.float16: |
|
|
|
x = quantize_per_tensor(x, self.a_qparams.scales.cuda().float( |
|
), self.a_qparams.zero_points.cuda().float()) if x.dtype is not torch.int8 else x |
|
|
|
return linear_on_quantized_data(w_tensor=self.weight, w_qparams=self.w_qparams, a_tensor=x, |
|
a_qparams=self.a_qparams, bias=self.bias) |
|
|
|
|
|
class QuantizedConv2d(nn.Module): |
|
|
|
def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1, |
|
padding=0, dilation=1, groups=1, bias=True, |
|
device=None, |
|
w_qparams=None, w_qparams_0=None, a_qparams=None, a_qparams_0=None, module_name=None, split=0) -> None: |
|
super().__init__() |
|
|
|
self.module_name = module_name |
|
self.split = split |
|
|
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.device = device |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.padding = padding |
|
self.dilation = dilation |
|
self.groups = groups |
|
|
|
|
|
self.w_qparams = w_qparams |
|
self.w_qparams_0 = w_qparams_0 |
|
self.a_qparams = a_qparams |
|
self.a_qparams_0 = a_qparams_0 |
|
|
|
if self.w_qparams is not None: |
|
self.register_buffer("weight_scales", self.w_qparams.scales) |
|
self.register_buffer("weight_zero_points", |
|
self.w_qparams.zero_points) |
|
        if self.w_qparams_0 is not None:
            self.register_buffer("weight_scales_0", self.w_qparams_0.scales)
            self.register_buffer("weight_zero_points_0",
                                 self.w_qparams_0.zero_points)
|
|
|
if self.a_qparams is not None: |
|
self.register_buffer("act_scales", self.a_qparams.scales) |
|
self.register_buffer("act_zero_points", self.a_qparams.zero_points) |
|
        if self.a_qparams_0 is not None:
            self.register_buffer("act_scales_0", self.a_qparams_0.scales)
            self.register_buffer("act_zero_points_0",
                                 self.a_qparams_0.zero_points)
|
|
|
@classmethod |
|
def from_float(cls, float_mod, split=0, ckpt=None): |
|
|
|
assert hasattr(float_mod, 'qconfig') and isinstance(float_mod.qconfig, |
|
QConfig) |
|
weight_process = float_mod.qconfig.weight() |
|
w_dtype = weight_process.dtype |
|
num_kernels = float_mod.weight.shape[0] |
|
device = float_mod.weight.device |
|
|
|
|
|
|
|
|
|
w_qparams, w_qparams_0 = create_qparams_from_dtype(dtype=w_dtype, |
|
device=device, |
|
is_channel_wise=True, |
|
num_kernels=num_kernels, |
|
ckpt=ckpt, |
|
module_name=float_mod.module_name, |
|
quant_type='weight', |
|
bit_width=float_mod.w_bit, |
|
split=split) |
|
|
|
act_process = float_mod.qconfig.activation() |
|
act_dtype = act_process.dtype |
|
|
|
|
|
if hasattr(float_mod, 'a_bit'): |
|
|
|
a_qparams, a_qparams_0 = create_qparams_from_dtype(dtype=act_dtype, |
|
device=device, |
|
is_channel_wise=False, |
|
num_kernels=num_kernels, |
|
ckpt=ckpt, |
|
module_name=float_mod.module_name, |
|
quant_type='act', |
|
bit_width=float_mod.a_bit, |
|
split=split) |
|
else: |
|
a_qparams = None |
|
a_qparams_0 = None |
|
|
|
new_mod = cls(float_mod.in_channels, |
|
float_mod.out_channels, |
|
float_mod.kernel_size, |
|
float_mod.stride, |
|
float_mod.padding, |
|
float_mod.dilation, |
|
float_mod.groups, |
|
float_mod.bias is not None, |
|
device=float_mod.weight.device, |
|
|
|
w_qparams=w_qparams, |
|
w_qparams_0=w_qparams_0, |
|
a_qparams=a_qparams, |
|
a_qparams_0=a_qparams_0, |
|
|
|
module_name=float_mod.module_name, |
|
split=split |
|
) |
|
|
|
weight = float_mod.weight.detach() |
|
|
|
if split == 0: |
|
if w_qparams is not None: |
|
weight = quantize_from_qparams(weight, w_qparams) |
|
new_mod.register_buffer("weight", weight) |
|
if float_mod.bias is not None: |
|
bias = float_mod.bias.detach() |
|
new_mod.register_buffer("bias", bias) |
|
else: |
|
new_mod.bias = None |
|
|
|
|
|
elif split > 0: |
|
if w_qparams is not None: |
|
weight_0 = quantize_from_qparams( |
|
weight[:, :split, ...], w_qparams) |
|
weight_1 = quantize_from_qparams( |
|
weight[:, split:, ...], w_qparams_0) |
|
|
|
new_mod.register_buffer("weight", weight_0) |
|
new_mod.register_buffer("weight_0", weight_1) |
|
if float_mod.bias is not None: |
|
bias = float_mod.bias.detach() |
|
new_mod.register_buffer("bias", bias) |
|
else: |
|
new_mod.bias = None |
|
|
|
return new_mod |
|
|
|
def _get_name(self): |
|
w_width = 16 if self.w_qparams is None else \ |
|
dtype_to_bw[self.w_qparams.dtype] |
|
a_width = 16 if self.a_qparams is None else \ |
|
dtype_to_bw[self.a_qparams.dtype] |
|
return f"QuantizedConv2d(W({w_width})A({a_width}))" |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
if self.a_qparams is not None and x.dtype == torch.float16: |
|
if self.split == 0: |
|
|
|
x_0 = quantize_per_tensor(x, self.a_qparams.scales.cuda( |
|
).float(), self.a_qparams.zero_points.cuda().float()) |
|
|
|
return conv2d_on_quantized_data(w_tensor=self.weight, |
|
w_qparams=self.w_qparams, |
|
a_tensor=x_0, |
|
a_qparams=self.a_qparams, |
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
split=self.split) |
|
elif self.split > 0: |
|
|
|
|
|
x_0 = quantize_per_tensor(x[:, :self.split, :, :], self.a_qparams.scales.cuda( |
|
).float(), self.a_qparams.zero_points.cuda().float()) |
|
x_1 = quantize_per_tensor(x[:, self.split:, :, :], self.a_qparams_0.scales.cuda( |
|
).float(), self.a_qparams_0.zero_points.cuda().float()) |
|
|
|
return conv2d_on_quantized_data(w_tensor=self.weight, |
|
w_tensor_0=self.weight_0, |
|
w_qparams=self.w_qparams, |
|
w_qparams_0=self.w_qparams_0, |
|
|
|
a_tensor=x_0, |
|
a_tensor_0=x_1, |
|
a_qparams=self.a_qparams, |
|
a_qparams_0=self.a_qparams_0, |
|
|
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
split=self.split) |
|
else: |
|
if self.split == 0: |
|
return conv2d_on_quantized_data(w_tensor=self.weight, |
|
w_qparams=self.w_qparams, |
|
a_tensor=x, |
|
a_qparams=self.a_qparams, |
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
split=self.split) |
|
elif self.split > 0: |
|
return conv2d_on_quantized_data(w_tensor=self.weight, |
|
w_tensor_0=self.weight_0, |
|
w_qparams=self.w_qparams, |
|
w_qparams_0=self.w_qparams_0, |
|
|
|
a_tensor=x, |
|
a_tensor_0=None, |
|
a_qparams=self.a_qparams, |
|
a_qparams_0=self.a_qparams_0, |
|
|
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
split=self.split) |
|
|
|
|
|
|
|
|
|
|
|
def make_memory_friendly(num_bytes):
    """Formats a byte count as G/M/K/Bytes plus the total in MB."""

    MBs = num_bytes / (1024 * 1024)

    B = num_bytes % 1024
    num_bytes = num_bytes // 1024
    kB = num_bytes % 1024
    num_bytes = num_bytes // 1024
    MB = num_bytes % 1024
    GB = num_bytes // 1024

    return f"{GB} G {MB} M {kB} K {B} Bytes ({MBs:.2f} MBs)"
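# Example: make_memory_friendly(3 * 1024**3 + 5 * 1024**2 + 7) returns
# "3 G 5 M 0 K 7 Bytes (3077.00 MBs)".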
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): |
|
""" |
|
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and |
|
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 |
|
""" |
|
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
return noise_cfg |
|
|
|
|
|
def retrieve_timesteps( |
|
scheduler, |
|
num_inference_steps: Optional[int] = None, |
|
device: Optional[Union[str, torch.device]] = None, |
|
timesteps: Optional[List[int]] = None, |
|
**kwargs, |
|
): |
|
""" |
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
|
|
|
Args: |
|
scheduler (`SchedulerMixin`): |
|
The scheduler to get timesteps from. |
|
num_inference_steps (`int`): |
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, |
|
`timesteps` must be `None`. |
|
device (`str` or `torch.device`, *optional*): |
|
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
timesteps (`List[int]`, *optional*): |
|
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default |
|
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` |
|
must be `None`. |
|
|
|
Returns: |
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
|
second element is the number of inference steps. |
|
""" |
|
if timesteps is not None: |
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
|
if not accepts_timesteps: |
|
raise ValueError( |
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
|
f" timestep schedules. Please check whether you are using the correct scheduler." |
|
) |
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
num_inference_steps = len(timesteps) |
|
else: |
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
|
timesteps = scheduler.timesteps |
|
return timesteps, num_inference_steps |
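# Minimal usage sketch (`pipe` and the device string are assumptions):
#
#   timesteps, n = retrieve_timesteps(pipe.scheduler, num_inference_steps=4, device="cuda")
#   # or, with a scheduler whose `set_timesteps` accepts custom timesteps:
#   timesteps, n = retrieve_timesteps(pipe.scheduler, timesteps=[999, 749, 499, 249], device="cuda")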
|
|
|
|
|
class MixDQ_SDXLTurbo_Pipeline_W8A8( |
|
DiffusionPipeline, |
|
FromSingleFileMixin, |
|
StableDiffusionXLLoraLoaderMixin, |
|
TextualInversionLoaderMixin, |
|
IPAdapterMixin, |
|
): |
|
r""" |
|
Pipeline for text-to-image generation using Stable Diffusion XL. |
|
|
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the |
|
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) |
|
|
|
    In addition, the pipeline inherits the following loading methods:
|
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] |
|
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] |
|
|
|
as well as the following saving methods: |
|
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] |
|
|
|
Args: |
|
vae ([`AutoencoderKL`]): |
|
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. |
|
text_encoder ([`CLIPTextModel`]): |
|
Frozen text-encoder. Stable Diffusion XL uses the text portion of |
|
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically |
|
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. |
|
        text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of |
|
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), |
|
specifically the |
|
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) |
|
variant. |
|
tokenizer (`CLIPTokenizer`): |
|
Tokenizer of class |
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). |
|
tokenizer_2 (`CLIPTokenizer`): |
|
Second Tokenizer of class |
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). |
|
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. |
|
scheduler ([`SchedulerMixin`]): |
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of |
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. |
|
        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
            `stabilityai/stable-diffusion-xl-base-1.0`.
|
add_watermarker (`bool`, *optional*): |
|
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to |
|
watermark output images. If not defined, it will default to True if the package is installed, otherwise no |
|
watermarker will be used. |
|
""" |
|
|
|
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" |
|
_optional_components = [ |
|
"tokenizer", |
|
"tokenizer_2", |
|
"text_encoder", |
|
"text_encoder_2", |
|
"image_encoder", |
|
"feature_extractor", |
|
] |
|
_callback_tensor_inputs = [ |
|
"latents", |
|
"prompt_embeds", |
|
"negative_prompt_embeds", |
|
"add_text_embeds", |
|
"add_time_ids", |
|
"negative_pooled_prompt_embeds", |
|
"negative_add_time_ids", |
|
] |
|
|
|
def __init__( |
|
self, |
|
vae: AutoencoderKL, |
|
text_encoder: CLIPTextModel, |
|
text_encoder_2: CLIPTextModelWithProjection, |
|
tokenizer: CLIPTokenizer, |
|
tokenizer_2: CLIPTokenizer, |
|
unet: UNet2DConditionModel, |
|
scheduler: KarrasDiffusionSchedulers, |
|
image_encoder: CLIPVisionModelWithProjection = None, |
|
feature_extractor: CLIPImageProcessor = None, |
|
force_zeros_for_empty_prompt: bool = True, |
|
add_watermarker: Optional[bool] = None, |
|
): |
|
super().__init__() |
|
|
|
self.register_modules( |
|
vae=vae, |
|
text_encoder=text_encoder, |
|
text_encoder_2=text_encoder_2, |
|
tokenizer=tokenizer, |
|
tokenizer_2=tokenizer_2, |
|
unet=unet, |
|
scheduler=scheduler, |
|
image_encoder=image_encoder, |
|
feature_extractor=feature_extractor, |
|
) |
|
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) |
|
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) |
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) |
|
|
|
self.default_sample_size = self.unet.config.sample_size |
|
|
|
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() |
|
|
|
if add_watermarker: |
|
self.watermark = StableDiffusionXLWatermarker() |
|
else: |
|
self.watermark = None |
|
|
|
|
|
def enable_vae_slicing(self): |
|
r""" |
|
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to |
|
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. |
|
""" |
|
self.vae.enable_slicing() |
|
|
|
|
|
def disable_vae_slicing(self): |
|
r""" |
|
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to |
|
computing decoding in one step. |
|
""" |
|
self.vae.disable_slicing() |
|
|
|
|
|
def enable_vae_tiling(self): |
|
r""" |
|
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to |
|
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow |
|
processing larger images. |
|
""" |
|
self.vae.enable_tiling() |
|
|
|
|
|
def disable_vae_tiling(self): |
|
r""" |
|
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to |
|
computing decoding in one step. |
|
""" |
|
self.vae.disable_tiling() |
|
|
|
def encode_prompt( |
|
self, |
|
prompt: str, |
|
prompt_2: Optional[str] = None, |
|
device: Optional[torch.device] = None, |
|
num_images_per_prompt: int = 1, |
|
do_classifier_free_guidance: bool = True, |
|
negative_prompt: Optional[str] = None, |
|
negative_prompt_2: Optional[str] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
lora_scale: Optional[float] = None, |
|
clip_skip: Optional[int] = None, |
|
): |
|
r""" |
|
Encodes the prompt into text encoder hidden states. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
prompt to be encoded |
|
prompt_2 (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is |
|
used in both text-encoders |
|
device: (`torch.device`): |
|
torch device |
|
num_images_per_prompt (`int`): |
|
number of images that should be generated per prompt |
|
do_classifier_free_guidance (`bool`): |
|
whether to use classifier free guidance or not |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass |
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is |
|
less than `1`). |
|
negative_prompt_2 (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and |
|
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders |
|
prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
|
provided, text embeddings will be generated from `prompt` input argument. |
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
|
argument. |
|
pooled_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. |
|
If not provided, pooled text embeddings will be generated from `prompt` input argument. |
|
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` |
|
input argument. |
|
lora_scale (`float`, *optional*): |
|
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. |
|
clip_skip (`int`, *optional*): |
|
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that |
|
the output of the pre-final layer will be used for computing the prompt embeddings. |
|
""" |
|
device = device or self._execution_device |
|
|
|
|
|
|
|
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): |
|
self._lora_scale = lora_scale |
|
|
|
|
|
if self.text_encoder is not None: |
|
if not USE_PEFT_BACKEND: |
|
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) |
|
else: |
|
scale_lora_layers(self.text_encoder, lora_scale) |
|
|
|
if self.text_encoder_2 is not None: |
|
if not USE_PEFT_BACKEND: |
|
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) |
|
else: |
|
scale_lora_layers(self.text_encoder_2, lora_scale) |
|
|
|
prompt = [prompt] if isinstance(prompt, str) else prompt |
|
|
|
if prompt is not None: |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
|
|
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] |
|
text_encoders = ( |
|
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] |
|
) |
|
|
|
if prompt_embeds is None: |
|
prompt_2 = prompt_2 or prompt |
|
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 |
|
|
|
|
|
prompt_embeds_list = [] |
|
prompts = [prompt, prompt_2] |
|
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): |
|
if isinstance(self, TextualInversionLoaderMixin): |
|
prompt = self.maybe_convert_prompt(prompt, tokenizer) |
|
|
|
text_inputs = tokenizer( |
|
prompt, |
|
padding="max_length", |
|
max_length=tokenizer.model_max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
|
|
text_input_ids = text_inputs.input_ids |
|
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids |
|
|
|
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( |
|
text_input_ids, untruncated_ids |
|
): |
|
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) |
|
logger.warning( |
|
"The following part of your input was truncated because CLIP can only handle sequences up to" |
|
f" {tokenizer.model_max_length} tokens: {removed_text}" |
|
) |
|
|
|
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) |
|
|
|
|
|
                # we are only ALWAYS interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
|
if clip_skip is None: |
|
prompt_embeds = prompt_embeds.hidden_states[-2] |
|
else: |
|
|
|
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] |
|
|
|
prompt_embeds_list.append(prompt_embeds) |
|
|
|
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) |
|
|
|
|
|
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt |
|
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: |
|
negative_prompt_embeds = torch.zeros_like(prompt_embeds) |
|
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) |
|
elif do_classifier_free_guidance and negative_prompt_embeds is None: |
|
negative_prompt = negative_prompt or "" |
|
negative_prompt_2 = negative_prompt_2 or negative_prompt |
|
|
|
|
|
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt |
|
negative_prompt_2 = ( |
|
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 |
|
) |
|
|
|
uncond_tokens: List[str] |
|
if prompt is not None and type(prompt) is not type(negative_prompt): |
|
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
|
elif batch_size != len(negative_prompt): |
|
raise ValueError( |
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" |
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" |
|
" the batch size of `prompt`." |
|
) |
|
else: |
|
uncond_tokens = [negative_prompt, negative_prompt_2] |
|
|
|
negative_prompt_embeds_list = [] |
|
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): |
|
if isinstance(self, TextualInversionLoaderMixin): |
|
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) |
|
|
|
max_length = prompt_embeds.shape[1] |
|
uncond_input = tokenizer( |
|
negative_prompt, |
|
padding="max_length", |
|
max_length=max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
|
|
negative_prompt_embeds = text_encoder( |
|
uncond_input.input_ids.to(device), |
|
output_hidden_states=True, |
|
) |
|
|
|
negative_pooled_prompt_embeds = negative_prompt_embeds[0] |
|
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] |
|
|
|
negative_prompt_embeds_list.append(negative_prompt_embeds) |
|
|
|
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) |
|
|
|
if self.text_encoder_2 is not None: |
|
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) |
|
else: |
|
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) |
|
|
|
bs_embed, seq_len, _ = prompt_embeds.shape |
|
|
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) |
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) |
|
|
|
if do_classifier_free_guidance: |
|
|
|
seq_len = negative_prompt_embeds.shape[1] |
|
|
|
if self.text_encoder_2 is not None: |
|
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) |
|
else: |
|
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) |
|
|
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) |
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) |
|
|
|
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( |
|
bs_embed * num_images_per_prompt, -1 |
|
) |
|
if do_classifier_free_guidance: |
|
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( |
|
bs_embed * num_images_per_prompt, -1 |
|
) |
|
|
|
if self.text_encoder is not None: |
|
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: |
|
|
|
unscale_lora_layers(self.text_encoder, lora_scale) |
|
|
|
if self.text_encoder_2 is not None: |
|
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: |
|
|
|
unscale_lora_layers(self.text_encoder_2, lora_scale) |
|
|
|
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds |
|
|
|
|
|
def encode_image(self, image, device, num_images_per_prompt): |
|
dtype = next(self.image_encoder.parameters()).dtype |
|
|
|
if not isinstance(image, torch.Tensor): |
|
image = self.feature_extractor(image, return_tensors="pt").pixel_values |
|
|
|
image = image.to(device=device, dtype=dtype) |
|
image_embeds = self.image_encoder(image).image_embeds |
|
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) |
|
|
|
uncond_image_embeds = torch.zeros_like(image_embeds) |
|
return image_embeds, uncond_image_embeds |
|
|
|
|
|
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler and will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
        # and should be in [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
extra_step_kwargs = {} |
|
if accepts_eta: |
|
extra_step_kwargs["eta"] = eta |
|
|
|
|
|
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
|
if accepts_generator: |
|
extra_step_kwargs["generator"] = generator |
|
return extra_step_kwargs |
|
|
|
def check_inputs( |
|
self, |
|
prompt, |
|
prompt_2, |
|
height, |
|
width, |
|
callback_steps, |
|
negative_prompt=None, |
|
negative_prompt_2=None, |
|
prompt_embeds=None, |
|
negative_prompt_embeds=None, |
|
pooled_prompt_embeds=None, |
|
negative_pooled_prompt_embeds=None, |
|
callback_on_step_end_tensor_inputs=None, |
|
): |
|
if height % 8 != 0 or width % 8 != 0: |
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") |
|
|
|
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): |
|
raise ValueError( |
|
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
|
f" {type(callback_steps)}." |
|
) |
|
|
|
if callback_on_step_end_tensor_inputs is not None and not all( |
|
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs |
|
): |
|
raise ValueError( |
|
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" |
|
) |
|
|
|
if prompt is not None and prompt_embeds is not None: |
|
raise ValueError( |
|
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" |
|
" only forward one of the two." |
|
) |
|
elif prompt_2 is not None and prompt_embeds is not None: |
|
raise ValueError( |
|
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" |
|
" only forward one of the two." |
|
) |
|
elif prompt is None and prompt_embeds is None: |
|
raise ValueError( |
|
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." |
|
) |
|
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): |
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") |
|
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): |
|
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") |
|
|
|
if negative_prompt is not None and negative_prompt_embeds is not None: |
|
raise ValueError( |
|
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" |
|
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." |
|
) |
|
elif negative_prompt_2 is not None and negative_prompt_embeds is not None: |
|
raise ValueError( |
|
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" |
|
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." |
|
) |
|
|
|
if prompt_embeds is not None and negative_prompt_embeds is not None: |
|
if prompt_embeds.shape != negative_prompt_embeds.shape: |
|
raise ValueError( |
|
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" |
|
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" |
|
f" {negative_prompt_embeds.shape}." |
|
) |
|
|
|
if prompt_embeds is not None and pooled_prompt_embeds is None: |
|
raise ValueError( |
|
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." |
|
) |
|
|
|
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: |
|
raise ValueError( |
|
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." |
|
) |
|
|
|
|
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): |
|
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) |
|
if isinstance(generator, list) and len(generator) != batch_size: |
|
raise ValueError( |
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
|
) |
|
|
|
if latents is None: |
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
|
else: |
|
latents = latents.to(device) |
|
|
|
|
|
latents = latents * self.scheduler.init_noise_sigma |
|
return latents |
|
|
|
def _get_add_time_ids( |
|
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None |
|
): |
|
add_time_ids = list(original_size + crops_coords_top_left + target_size) |
|
|
|
passed_add_embed_dim = ( |
|
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim |
|
) |
|
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features |
|
|
|
if expected_add_embed_dim != passed_add_embed_dim: |
|
raise ValueError( |
|
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." |
|
) |
|
|
|
add_time_ids = torch.tensor([add_time_ids], dtype=dtype) |
|
return add_time_ids |
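    # Illustrative values only: for a 512x512 SDXL-Turbo generation with no crop,
    # `original_size + crops_coords_top_left + target_size` flattens to
    # (512, 512, 0, 0, 512, 512), matching the `time_ids` tensor built in
    # `run_for_test` below.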
|
|
|
|
|
def upcast_vae(self): |
|
dtype = self.vae.dtype |
|
self.vae.to(dtype=torch.float32) |
|
use_torch_2_0_or_xformers = isinstance( |
|
self.vae.decoder.mid_block.attentions[0].processor, |
|
( |
|
AttnProcessor2_0, |
|
XFormersAttnProcessor, |
|
LoRAXFormersAttnProcessor, |
|
LoRAAttnProcessor2_0, |
|
), |
|
) |
|
|
|
|
|
if use_torch_2_0_or_xformers: |
|
self.vae.post_quant_conv.to(dtype) |
|
self.vae.decoder.conv_in.to(dtype) |
|
self.vae.decoder.mid_block.to(dtype) |
|
|
|
|
|
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): |
|
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. |
|
|
|
The suffixes after the scaling factors represent the stages where they are being applied. |
|
|
|
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values |
|
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. |
|
|
|
Args: |
|
s1 (`float`): |
|
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to |
|
mitigate "oversmoothing effect" in the enhanced denoising process. |
|
s2 (`float`): |
|
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to |
|
mitigate "oversmoothing effect" in the enhanced denoising process. |
|
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. |
|
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. |
|
""" |
|
if not hasattr(self, "unet"): |
|
raise ValueError("The pipeline must have `unet` for using FreeU.") |
|
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) |
|
|
|
|
|
def disable_freeu(self): |
|
"""Disables the FreeU mechanism if enabled.""" |
|
self.unet.disable_freeu() |
|
|
|
|
|
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): |
|
""" |
|
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 |
|
|
|
        Args:
            w (`torch.Tensor`):
                Guidance scale values used to generate the embedding vectors, with shape `(batch,)`.
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
        """
|
assert len(w.shape) == 1 |
|
w = w * 1000.0 |
|
|
|
half_dim = embedding_dim // 2 |
|
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) |
|
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) |
|
emb = w.to(dtype)[:, None] * emb[None, :] |
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) |
|
if embedding_dim % 2 == 1: |
|
emb = torch.nn.functional.pad(emb, (0, 1)) |
|
assert emb.shape == (w.shape[0], embedding_dim) |
|
return emb |
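    # Minimal usage sketch (values are illustrative):
    #
    #   w = torch.tensor([7.5])                       # guidance scale(s), shape (batch,)
    #   emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
    #   # emb.shape == (1, 256): a sinusoidal embedding of w * 1000, as consumed by
    #   # UNets with `time_cond_proj_dim` set (e.g. distilled/LCM-style models).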
|
|
|
def quantize_unet( |
|
self, |
|
w_bit = None, |
|
a_bit = None, |
|
bos=True, |
|
|
|
): |
|
        r"""
        Quantize the UNet of the SDXL pipeline.
        Currently only the W8A8 setting is supported.

        Args:
            w_bit: (`int`, *optional*):
                the bit-width used for weight quantization
            a_bit: (`int`, *optional*):
                the bit-width used for activation quantization
            bos: (`bool`):
                whether to use the BOS technique

        """
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
path = hf_hub_download( |
|
repo_id="Stein-Fun/mixdq_test", |
|
filename="bos_pre_computed.pt", |
|
revision="version_0", |
|
) |
|
bos_dict = torch.load(path, map_location='cpu') |
|
|
|
path = hf_hub_download( |
|
repo_id="Stein-Fun/mixdq_test", |
|
filename="quant_para_wsym_fp16.pt", |
|
revision="version_0", |
|
) |
|
ckpt = torch.load(path, map_location='cpu') |
|
|
|
register_qconfig_from_input_files( |
|
self.unet, |
|
|
|
w_bit = w_bit, |
|
a_bit = a_bit, |
|
bos=bos, |
|
bos_dict=bos_dict |
|
) |
|
convert_to_quantized(self.unet, ckpt) |
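    # Minimal usage sketch (the checkpoint names follow the hf_hub_download calls above):
    #
    #   pipe.quantize_unet(w_bit=8, a_bit=8, bos=True)   # W8A8 is the supported setting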
|
|
|
def run_for_test( |
|
self, |
|
device, |
|
prompt: str = "A black and white photo of an older man skiing.", |
|
batch_size: int = 1, |
|
output_type: str = "latent", |
|
run_pipeline: bool = False, |
|
path: str = "result.png" |
|
): |
|
        r"""
        Run a quick test of the memory reduction after quantization on GPU.

        Args:
            device: (`torch.device`):
                torch device; only CUDA devices are supported
            prompt: (`str`, *optional*):
                prompt to be encoded
            batch_size: (`int`):
                the batch size of inputs
            output_type (`str`, *optional*, defaults to `"latent"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            run_pipeline: (`bool`):
                whether to run the whole pipeline or just the UNet
            path: (`str`):
                the path to save the output image
        """
|
|
|
if run_pipeline: |
|
self.to(device) |
|
else: |
|
self.unet.to(device) |
|
|
|
model_memory = torch.cuda.memory_allocated() |
|
print("Static (weights) memory usage:", |
|
make_memory_friendly(model_memory)) |
|
|
|
|
|
if run_pipeline: |
|
|
|
latents = self(prompt=[prompt]*batch_size, |
|
guidance_scale=0.0, |
|
num_inference_steps=2, |
|
output_type=output_type).images[0] |
|
else: |
|
sample_shape = ( |
|
batch_size * 1, |
|
self.unet.config.in_channels, |
|
self.unet.config.sample_size, |
|
self.unet.config.sample_size, |
|
) |
|
|
|
encoder_embedding_shape = ( |
|
batch_size * 1, |
|
77, |
|
2048, |
|
) |
|
|
|
|
|
example_sample = torch.rand(*sample_shape, device=device, |
|
dtype=torch.float16) |
|
example_embedding = torch.rand(*encoder_embedding_shape, |
|
device=device, dtype=torch.float16) |
|
timestep = torch.tensor(999., device=device) |
|
text_embeds = torch.rand(batch_size, 1280, device=device, |
|
dtype=torch.float16) |
|
time_ids = torch.tensor([[512., 512., 0., 0., 512., 512.]], dtype=torch.float16, |
|
device=device) |
|
time_ids = torch.concat([time_ids] * batch_size) |
|
with torch.no_grad(): |
|
|
|
latents = self.unet(sample=example_sample, |
|
timestep=timestep, |
|
encoder_hidden_states=example_embedding, |
|
added_cond_kwargs={ |
|
'time_ids': time_ids, |
|
'text_embeds': text_embeds |
|
}, |
|
return_dict=False)[0] |
|
|
|
|
|
|
|
peak_memory = torch.cuda.max_memory_allocated() |
|
print("Dynamic (acts) memory usage:", |
|
make_memory_friendly(peak_memory - model_memory)) |
|
print("Peak (total) memory usage:", make_memory_friendly(peak_memory)) |
|
|
|
        if output_type == "pil":
            # when running the full pipeline with output_type="pil", `latents`
            # already holds the decoded image returned by the pipeline
            image = latents
            image.save(path)
|
return latents |
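    # Minimal usage sketch (device string is an assumption):
    #
    #   pipe.quantize_unet(w_bit=8, a_bit=8)
    #   pipe.run_for_test("cuda", run_pipeline=False)            # UNet-only memory probe
    #   pipe.run_for_test("cuda", run_pipeline=True,
    #                     output_type="pil", path="result.png")  # full 2-step pipeline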
|
|
|
@property |
|
def guidance_scale(self): |
|
return self._guidance_scale |
|
|
|
@property |
|
def guidance_rescale(self): |
|
return self._guidance_rescale |
|
|
|
@property |
|
def clip_skip(self): |
|
return self._clip_skip |
|
|
|
|
|
|
|
|
|
@property |
|
def do_classifier_free_guidance(self): |
|
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None |
|
|
|
@property |
|
def cross_attention_kwargs(self): |
|
return self._cross_attention_kwargs |
|
|
|
@property |
|
def denoising_end(self): |
|
return self._denoising_end |
|
|
|
@property |
|
def num_timesteps(self): |
|
return self._num_timesteps |
|
|
|
@torch.no_grad() |
|
@replace_example_docstring(EXAMPLE_DOC_STRING) |
|
def __call__( |
|
self, |
|
prompt: Union[str, List[str]] = None, |
|
prompt_2: Optional[Union[str, List[str]]] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
num_inference_steps: int = 50, |
|
timesteps: List[int] = None, |
|
denoising_end: Optional[float] = None, |
|
guidance_scale: float = 5.0, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
negative_prompt_2: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
ip_adapter_image: Optional[PipelineImageInput] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
guidance_rescale: float = 0.0, |
|
original_size: Optional[Tuple[int, int]] = None, |
|
crops_coords_top_left: Tuple[int, int] = (0, 0), |
|
target_size: Optional[Tuple[int, int]] = None, |
|
negative_original_size: Optional[Tuple[int, int]] = None, |
|
negative_crops_coords_top_left: Tuple[int, int] = (0, 0), |
|
negative_target_size: Optional[Tuple[int, int]] = None, |
|
clip_skip: Optional[int] = None, |
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
|
**kwargs, |
|
): |
|
r""" |
|
Function invoked when calling the pipeline for generation. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
|
prompt_2 (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is |
|
used in both text-encoders |
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
|
The height in pixels of the generated image. This is set to 1024 by default for the best results. |
|
Anything below 512 pixels won't work well for |
|
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) |
|
and checkpoints that are not specifically fine-tuned on low resolutions. |
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
|
The width in pixels of the generated image. This is set to 1024 by default for the best results. |
|
Anything below 512 pixels won't work well for |
|
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) |
|
and checkpoints that are not specifically fine-tuned on low resolutions. |
|
num_inference_steps (`int`, *optional*, defaults to 50): |
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
|
expense of slower inference. |
|
timesteps (`List[int]`, *optional*): |
|
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument |
|
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is |
|
passed will be used. Must be in descending order. |
|
denoising_end (`float`, *optional*): |
|
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be |
|
completed before it is intentionally prematurely terminated. As a result, the returned sample will |
|
still retain a substantial amount of noise as determined by the discrete timesteps selected by the |
|
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a |
|
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image |
|
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) |
|
guidance_scale (`float`, *optional*, defaults to 5.0): |
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen |
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
|
usually at the expense of lower image quality. |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass |
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is |
|
less than `1`). |
|
negative_prompt_2 (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and |
|
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders |
|
num_images_per_prompt (`int`, *optional*, defaults to 1): |
|
The number of images to generate per prompt. |
|
eta (`float`, *optional*, defaults to 0.0): |
|
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to |
|
[`schedulers.DDIMScheduler`], will be ignored for others. |
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) |
|
to make generation deterministic. |
|
latents (`torch.FloatTensor`, *optional*): |
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
|
                tensor will be generated by sampling using the supplied random `generator`.
|
prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
|
provided, text embeddings will be generated from `prompt` input argument. |
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
|
argument. |
|
pooled_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. |
|
If not provided, pooled text embeddings will be generated from `prompt` input argument. |
|
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` |
|
input argument. |
|
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. |
|
output_type (`str`, *optional*, defaults to `"pil"`): |
|
                The output format of the generated image. Choose between
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead |
|
of a plain tuple. |
|
cross_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under |
|
`self.processor` in |
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
|
guidance_rescale (`float`, *optional*, defaults to 0.0): |
|
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are |
|
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
|
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). |
|
Guidance rescale factor should fix overexposure when using zero terminal SNR. |
|
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
|
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. |
|
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as |
|
explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
|
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): |
|
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position |
|
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting |
|
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
|
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
|
For most cases, `target_size` should be set to the desired height and width of the generated image. If |
|
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in |
|
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
|
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
|
To negatively condition the generation process based on a specific image resolution. Part of SDXL's |
|
micro-conditioning as explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more |
|
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. |
|
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): |
|
                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
micro-conditioning as explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more |
|
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. |
|
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
|
                To negatively condition the generation process based on a target image resolution. It should be the same
|
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of |
|
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more |
|
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. |
|
callback_on_step_end (`Callable`, *optional*): |
|
                A function that is called at the end of each denoising step during inference. The function is called
|
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, |
|
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by |
|
`callback_on_step_end_tensor_inputs`. |
|
callback_on_step_end_tensor_inputs (`List`, *optional*): |
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list |
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the |
|
`._callback_tensor_inputs` attribute of your pipeline class. |
|
|
|
Examples: |
|
|
|
Returns: |
|
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: |
|
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a |
|
`tuple`. When returning a tuple, the first element is a list with the generated images. |
|
""" |
|
|
|
callback = kwargs.pop("callback", None) |
|
callback_steps = kwargs.pop("callback_steps", None) |
|
|
|
if callback is not None: |
|
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
|
if callback_steps is not None: |
|
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
|
|
|
|
|
height = height or self.default_sample_size * self.vae_scale_factor |
|
width = width or self.default_sample_size * self.vae_scale_factor |
|
|
|
original_size = original_size or (height, width) |
|
target_size = target_size or (height, width) |
|
|
|
|
|
self.check_inputs( |
|
prompt, |
|
prompt_2, |
|
height, |
|
width, |
|
callback_steps, |
|
negative_prompt, |
|
negative_prompt_2, |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
pooled_prompt_embeds, |
|
negative_pooled_prompt_embeds, |
|
callback_on_step_end_tensor_inputs, |
|
) |
|
|
|
self._guidance_scale = guidance_scale |
|
self._guidance_rescale = guidance_rescale |
|
self._clip_skip = clip_skip |
|
self._cross_attention_kwargs = cross_attention_kwargs |
|
self._denoising_end = denoising_end |
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
|
|
|
|
lora_scale = ( |
|
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None |
|
) |
|
|
|
( |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
pooled_prompt_embeds, |
|
negative_pooled_prompt_embeds, |
|
) = self.encode_prompt( |
|
prompt=prompt, |
|
prompt_2=prompt_2, |
|
device=device, |
|
num_images_per_prompt=num_images_per_prompt, |
|
do_classifier_free_guidance=self.do_classifier_free_guidance, |
|
negative_prompt=negative_prompt, |
|
negative_prompt_2=negative_prompt_2, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
pooled_prompt_embeds=pooled_prompt_embeds, |
|
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, |
|
lora_scale=lora_scale, |
|
clip_skip=self.clip_skip, |
|
) |
|
|
|
|
|
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) |
|
|
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
add_text_embeds = pooled_prompt_embeds |
|
if self.text_encoder_2 is None: |
|
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) |
|
else: |
|
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim |
|
|
|
add_time_ids = self._get_add_time_ids( |
|
original_size, |
|
crops_coords_top_left, |
|
target_size, |
|
dtype=prompt_embeds.dtype, |
|
text_encoder_projection_dim=text_encoder_projection_dim, |
|
) |
|
if negative_original_size is not None and negative_target_size is not None: |
|
negative_add_time_ids = self._get_add_time_ids( |
|
negative_original_size, |
|
negative_crops_coords_top_left, |
|
negative_target_size, |
|
dtype=prompt_embeds.dtype, |
|
text_encoder_projection_dim=text_encoder_projection_dim, |
|
) |
|
else: |
|
negative_add_time_ids = add_time_ids |
|
|
|
if self.do_classifier_free_guidance: |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
|
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) |
|
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) |
|
|
|
prompt_embeds = prompt_embeds.to(device) |
|
add_text_embeds = add_text_embeds.to(device) |
|
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) |
|
|
|
if ip_adapter_image is not None: |
|
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) |
|
if self.do_classifier_free_guidance: |
|
image_embeds = torch.cat([negative_image_embeds, image_embeds]) |
|
image_embeds = image_embeds.to(device) |
|
|
|
|
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) |
|
|
|
|
|
if ( |
|
self.denoising_end is not None |
|
and isinstance(self.denoising_end, float) |
|
and self.denoising_end > 0 |
|
and self.denoising_end < 1 |
|
): |
|
discrete_timestep_cutoff = int( |
|
round( |
|
self.scheduler.config.num_train_timesteps |
|
- (self.denoising_end * self.scheduler.config.num_train_timesteps) |
|
) |
|
) |
|
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) |
|
timesteps = timesteps[:num_inference_steps] |
|
|
|
|
|
timestep_cond = None |
|
if self.unet.config.time_cond_proj_dim is not None: |
|
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) |
|
timestep_cond = self.get_guidance_scale_embedding( |
|
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim |
|
).to(device=device, dtype=latents.dtype) |
|
|
|
self._num_timesteps = len(timesteps) |
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
|
|
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents |
|
|
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} |
|
if ip_adapter_image is not None: |
|
added_cond_kwargs["image_embeds"] = image_embeds |
|
noise_pred = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
timestep_cond=timestep_cond, |
|
cross_attention_kwargs=self.cross_attention_kwargs, |
|
added_cond_kwargs=added_cond_kwargs, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if self.do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: |
|
|
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) |
|
|
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
|
|
|
if callback_on_step_end is not None: |
|
callback_kwargs = {} |
|
for k in callback_on_step_end_tensor_inputs: |
|
callback_kwargs[k] = locals()[k] |
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
|
|
|
latents = callback_outputs.pop("latents", latents) |
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) |
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) |
|
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) |
|
negative_pooled_prompt_embeds = callback_outputs.pop( |
|
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds |
|
) |
|
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) |
|
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
step_idx = i // getattr(self.scheduler, "order", 1) |
|
callback(step_idx, t, latents) |
|
|
|
|
|
|
|
|
|
if not output_type == "latent": |
|
|
|
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast |
|
|
|
if needs_upcasting: |
|
self.upcast_vae() |
|
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) |
|
|
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
|
|
|
|
|
if needs_upcasting: |
|
self.vae.to(dtype=torch.float16) |
|
else: |
|
image = latents |
|
|
|
if not output_type == "latent": |
|
|
|
if self.watermark is not None: |
|
image = self.watermark.apply_watermark(image) |
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type) |
|
|
|
|
|
self.maybe_free_model_hooks() |
|
|
|
if not return_dict: |
|
return (image,) |
|
|
|
return StableDiffusionXLPipelineOutput(images=image) |
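# Minimal end-to-end sketch for the quantized pipeline (a sketch, assuming the
# SDXL-Turbo weights; not part of the original file):
#
#   pipe = MixDQ_SDXLTurbo_Pipeline_W8A8.from_pretrained(
#       "stabilityai/sdxl-turbo", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.quantize_unet(w_bit=8, a_bit=8, bos=True)
#   image = pipe(prompt="a photo of a cat", guidance_scale=0.0,
#                num_inference_steps=1).images[0]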
|
|
|
|
|
|
|
|
|
a8_mixed_precision_config = \ |
|
{ |
|
'add_embedding.linear_1':8,'add_embedding.linear_2':8,'down_blocks.0.downsamplers.0.conv':8,'down_blocks.0.resnets.0.conv1':8,'down_blocks.0.resnets.0.time_emb_proj':8,'down_blocks.0.resnets.1.conv1':8,'down_blocks.0.resnets.1.conv2':8,'down_blocks.0.resnets.1.time_emb_proj':8,'down_blocks.1.attentions.0.proj_in':8,'down_blocks.1.attentions.0.proj_out':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v':8,'down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v':8,'down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2':8, |
|
'down_blocks.1.attentions.1.proj_in':8,'down_blocks.1.attentions.1.proj_out':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2':8,'down_blocks.1.downsamplers.0.conv':8,'down_blocks.1.resnets.0.conv1':8,'down_blocks.1.resnets.0.conv2':8,'down_blocks.1.resnets.0.conv_shortcut':8,'down_blocks.1.resnets.0.time_emb_proj':8,'down_blocks.1.resnets.1.conv1':8,'down_blocks.1.resnets.1.conv2':8,'down_blocks.1.resnets.1.time_emb_proj':8, |
|
'down_blocks.2.attentions.0.proj_in':8,'down_blocks.2.attentions.0.proj_out':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v':8, |
|
'down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_v':8, |
|
'down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_v':8, |
|
'down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2':8,'down_blocks.2.attentions.1.proj_in':8,'down_blocks.2.attentions.1.proj_out':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0':8, |
|
'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0':8, |
|
'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_v':8, |
|
'down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2':8,'down_blocks.2.resnets.0.conv1':8,'down_blocks.2.resnets.0.conv2':8,'down_blocks.2.resnets.0.conv_shortcut':8,'down_blocks.2.resnets.0.time_emb_proj':8,'down_blocks.2.resnets.1.conv1':8,'down_blocks.2.resnets.1.conv2':8,'down_blocks.2.resnets.1.time_emb_proj':8,'mid_block.attentions.0.proj_in':8,'mid_block.attentions.0.proj_out':8,'mid_block.attentions.0.transformer_blocks.0.attn1.to_k':8, |
|
'mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.0.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.0.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.0.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.0.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.0.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.0.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.1.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.2.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.3.attn1.to_k':8, |
|
'mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.3.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.3.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.3.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.3.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.3.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.3.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.4.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.5.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.6.attn1.to_k':8, |
|
'mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.6.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.6.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.6.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.6.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.6.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.6.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.7.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.8.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.9.attn1.to_k':8, |
|
'mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.9.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.9.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.9.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.9.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.9.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.9.ff.net.2':8,'mid_block.resnets.0.conv1':8,'mid_block.resnets.0.conv2':8,'mid_block.resnets.0.time_emb_proj':8,'mid_block.resnets.1.conv1':8,'mid_block.resnets.1.conv2':8,'mid_block.resnets.1.time_emb_proj':8,'time_embedding.linear_1':8,'time_embedding.linear_2':8,'up_blocks.0.attentions.0.proj_in':8,'up_blocks.0.attentions.0.proj_out':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0':8, |
|
'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0':8, |
|
'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0':8, |
|
'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2':8,'up_blocks.0.attentions.1.proj_in':8,'up_blocks.0.attentions.1.proj_out':8, |
|
'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2':8, |
|
'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2':8, |
|
'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2':8, |
|
'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2':8,'up_blocks.0.attentions.2.proj_in':8,'up_blocks.0.attentions.2.proj_out':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v':8, |
|
'up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v':8, |
|
'up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v':8, |
|
'up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2':8,'up_blocks.0.resnets.0.conv1':8,'up_blocks.0.resnets.0.conv2':8,'up_blocks.0.resnets.0.conv_shortcut':8,'up_blocks.0.resnets.0.time_emb_proj':8,'up_blocks.0.resnets.1.conv1':8,'up_blocks.0.resnets.1.conv2':8,'up_blocks.0.resnets.1.conv_shortcut':8,'up_blocks.0.resnets.1.time_emb_proj':8, |
|
'up_blocks.0.resnets.2.conv1':8,'up_blocks.0.resnets.2.conv2':8,'up_blocks.0.resnets.2.conv_shortcut':8,'up_blocks.0.resnets.2.time_emb_proj':8,'up_blocks.0.upsamplers.0.conv':8,'up_blocks.1.attentions.0.proj_in':8,'up_blocks.1.attentions.0.proj_out':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2':8,'up_blocks.1.attentions.1.proj_in':8,'up_blocks.1.attentions.1.proj_out':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k':8, |
|
'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v':8,'up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v':8,'up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2':8,'up_blocks.1.attentions.2.proj_in':8,'up_blocks.1.attentions.2.proj_out':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj':8, |
|
'up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2':8,'up_blocks.1.resnets.0.conv1':8,'up_blocks.1.resnets.0.conv2':8,'up_blocks.1.resnets.0.conv_shortcut':8,'up_blocks.1.resnets.0.time_emb_proj':8,'up_blocks.1.resnets.1.conv1':8,'up_blocks.1.resnets.1.conv2':8,'up_blocks.1.resnets.1.conv_shortcut':8,'up_blocks.1.resnets.1.time_emb_proj':8,'up_blocks.1.resnets.2.conv1':8,'up_blocks.1.resnets.2.conv2':8,'up_blocks.1.resnets.2.conv_shortcut':8,'up_blocks.1.resnets.2.time_emb_proj':8,'up_blocks.1.upsamplers.0.conv':8,'up_blocks.2.resnets.0.conv1':8,'up_blocks.2.resnets.0.conv2':8,'up_blocks.2.resnets.0.conv_shortcut':8,'up_blocks.2.resnets.0.time_emb_proj':8,'up_blocks.2.resnets.1.conv1':8,'up_blocks.2.resnets.1.conv2':8, |
|
'up_blocks.2.resnets.1.conv_shortcut':8,'up_blocks.2.resnets.1.time_emb_proj':8,'up_blocks.2.resnets.2.conv1':8,'up_blocks.2.resnets.2.conv2':8,'up_blocks.2.resnets.2.time_emb_proj':8, |
|
} |
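
# The dictionaries above and below follow the same scheme: keys are
# fully-qualified UNet submodule names (as produced by
# torch.nn.Module.named_modules()), and values are weight bit-widths.
# A minimal sketch of how such a config could be consumed is given here;
# `apply_bitwidth_config` is a hypothetical helper added for illustration
# only and is not part of the original code.
def apply_bitwidth_config(unet, bit_config):
    """Attach the configured weight bit-width to each listed submodule."""
    for name, module in unet.named_modules():
        if name in bit_config:
            # Record the chosen bit-width on the module; downstream
            # quantization code could read this attribute when building
            # per-layer weight quantizers.
            module.weight_bit = bit_config[name]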
|
|
|
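# Uniform 8-bit weight config: every quantizable UNet submodule listed below
# is assigned an 8-bit weight bit-width.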
w8_uniform_config = {
|
'conv_in':8,'time_embedding.linear_1':8,'time_embedding.linear_2':8,'add_embedding.linear_1':8,'add_embedding.linear_2':8,'down_blocks.0.resnets.0.conv1':8,'down_blocks.0.resnets.0.time_emb_proj':8,'down_blocks.0.resnets.0.conv2':8,'down_blocks.0.resnets.1.conv1':8,'down_blocks.0.resnets.1.time_emb_proj':8,'down_blocks.0.resnets.1.conv2':8,'down_blocks.0.downsamplers.0.conv':8,'down_blocks.1.attentions.0.proj_in':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v':8,'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k':8,'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v':8, |
|
'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2':8,'down_blocks.1.attentions.0.proj_out':8,'down_blocks.1.attentions.1.proj_in':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v':8,'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0':8,'down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2':8,'down_blocks.1.attentions.1.proj_out':8,'down_blocks.1.resnets.0.conv1':8,'down_blocks.1.resnets.0.time_emb_proj':8,'down_blocks.1.resnets.0.conv2':8,'down_blocks.1.resnets.0.conv_shortcut':8, |
|
'down_blocks.1.resnets.1.conv1':8,'down_blocks.1.resnets.1.time_emb_proj':8,'down_blocks.1.resnets.1.conv2':8,'down_blocks.1.downsamplers.0.conv':8,'down_blocks.2.attentions.0.proj_in':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q':8, |
|
'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_q':8, |
|
'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_q':8, |
|
'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_q':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_k':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_v':8,'down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0':8,'down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj':8,'down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2':8,'down_blocks.2.attentions.0.proj_out':8,'down_blocks.2.attentions.1.proj_in':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v':8, |
|
'down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_v':8, |
|
'down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_v':8, |
|
'down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_q':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_k':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_v':8,'down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0':8,'down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj':8,'down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2':8,'down_blocks.2.attentions.1.proj_out':8,'down_blocks.2.resnets.0.conv1':8,'down_blocks.2.resnets.0.time_emb_proj':8, |
|
'down_blocks.2.resnets.0.conv2':8,'down_blocks.2.resnets.0.conv_shortcut':8,'down_blocks.2.resnets.1.conv1':8,'down_blocks.2.resnets.1.time_emb_proj':8,'down_blocks.2.resnets.1.conv2':8,'up_blocks.0.attentions.0.proj_in':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0':8, |
|
'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0':8, |
|
'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0':8, |
|
'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_q':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_k':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_v':8,'up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0':8,'up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj':8,'up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2':8,'up_blocks.0.attentions.0.proj_out':8,'up_blocks.0.attentions.1.proj_in':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k':8, |
|
'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_k':8, |
|
'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_k':8, |
|
'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_q':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_k':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_v':8,'up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0':8,'up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj':8,'up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2':8,'up_blocks.0.attentions.1.proj_out':8,'up_blocks.0.attentions.2.proj_in':8, |
|
'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2':8, |
|
'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2':8, |
|
'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2':8, |
|
'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v':8,'up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0':8,'up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj':8,'up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2':8,'up_blocks.0.attentions.2.proj_out':8,'up_blocks.0.resnets.0.conv1':8,'up_blocks.0.resnets.0.time_emb_proj':8,'up_blocks.0.resnets.0.conv2':8,'up_blocks.0.resnets.0.conv_shortcut':8,'up_blocks.0.resnets.1.conv1':8,'up_blocks.0.resnets.1.time_emb_proj':8,'up_blocks.0.resnets.1.conv2':8,'up_blocks.0.resnets.1.conv_shortcut':8,'up_blocks.0.resnets.2.conv1':8,'up_blocks.0.resnets.2.time_emb_proj':8,'up_blocks.0.resnets.2.conv2':8,'up_blocks.0.resnets.2.conv_shortcut':8,'up_blocks.0.upsamplers.0.conv':8,'up_blocks.1.attentions.0.proj_in':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q':8, |
|
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v':8,'up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2':8,'up_blocks.1.attentions.0.proj_out':8,'up_blocks.1.attentions.1.proj_in':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v':8,'up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v':8, |
|
'up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v':8,'up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2':8,'up_blocks.1.attentions.1.proj_out':8,'up_blocks.1.attentions.2.proj_in':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj':8,'up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v':8,'up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0':8,'up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj':8,'up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2':8,'up_blocks.1.attentions.2.proj_out':8, |
|
'up_blocks.1.resnets.0.conv1':8,'up_blocks.1.resnets.0.time_emb_proj':8,'up_blocks.1.resnets.0.conv2':8,'up_blocks.1.resnets.0.conv_shortcut':8,'up_blocks.1.resnets.1.conv1':8,'up_blocks.1.resnets.1.time_emb_proj':8,'up_blocks.1.resnets.1.conv2':8,'up_blocks.1.resnets.1.conv_shortcut':8,'up_blocks.1.resnets.2.conv1':8,'up_blocks.1.resnets.2.time_emb_proj':8,'up_blocks.1.resnets.2.conv2':8,'up_blocks.1.resnets.2.conv_shortcut':8,'up_blocks.1.upsamplers.0.conv':8,'up_blocks.2.resnets.0.conv1':8,'up_blocks.2.resnets.0.time_emb_proj':8,'up_blocks.2.resnets.0.conv2':8,'up_blocks.2.resnets.0.conv_shortcut':8,'up_blocks.2.resnets.1.conv1':8,'up_blocks.2.resnets.1.time_emb_proj':8,'up_blocks.2.resnets.1.conv2':8,'up_blocks.2.resnets.1.conv_shortcut':8,'up_blocks.2.resnets.2.conv1':8,'up_blocks.2.resnets.2.time_emb_proj':8,'up_blocks.2.resnets.2.conv2':8,'up_blocks.2.resnets.2.conv_shortcut':8,'mid_block.attentions.0.proj_in':8,'mid_block.attentions.0.transformer_blocks.0.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.0.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.0.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0':8, |
|
'mid_block.attentions.0.transformer_blocks.0.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.0.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.0.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.0.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.1.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.2.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.3.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.3.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.3.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0':8, |
|
'mid_block.attentions.0.transformer_blocks.3.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.3.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.3.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.3.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.4.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.5.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.6.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.6.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.6.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0':8, |
|
'mid_block.attentions.0.transformer_blocks.6.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.6.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.6.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.6.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.7.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.8.ff.net.2':8,'mid_block.attentions.0.transformer_blocks.9.attn1.to_q':8,'mid_block.attentions.0.transformer_blocks.9.attn1.to_k':8,'mid_block.attentions.0.transformer_blocks.9.attn1.to_v':8,'mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0':8, |
|
'mid_block.attentions.0.transformer_blocks.9.attn2.to_q':8,'mid_block.attentions.0.transformer_blocks.9.attn2.to_k':8,'mid_block.attentions.0.transformer_blocks.9.attn2.to_v':8,'mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0':8,'mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj':8,'mid_block.attentions.0.transformer_blocks.9.ff.net.2':8,'mid_block.attentions.0.proj_out':8,'mid_block.resnets.0.conv1':8,'mid_block.resnets.0.time_emb_proj':8,'mid_block.resnets.0.conv2':8,'mid_block.resnets.1.conv1':8,'mid_block.resnets.1.time_emb_proj':8,'mid_block.resnets.1.conv2':8,'conv_out':8, |
|
} |
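
# --- Illustrative sketch (not part of the original config) -------------------
# The dict above maps qualified UNet submodule names to weight bit-widths
# (uniformly 8 in this tail of the config). A minimal, hypothetical example of
# how such a map could be consumed when attaching per-layer weight observers.
# The names `build_weight_observer`, `attach_weight_observers`, `bitwidth_map`,
# and the `weight_observer` attribute are assumptions for illustration only;
# only `MinMaxObserver.with_args` and `nn.Module.named_modules` are real
# torch APIs.

def build_weight_observer(n_bits: int):
    # Symmetric signed integer range for an n-bit weight quantizer,
    # e.g. [-128, 127] for 8 bits.
    qmin = -(2 ** (n_bits - 1))
    qmax = 2 ** (n_bits - 1) - 1
    return MinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        quant_min=qmin,
        quant_max=qmax,
    )


def attach_weight_observers(unet: nn.Module, bitwidth_map: dict):
    # Walk the UNet and attach a weight observer to every submodule whose
    # qualified name (as produced by named_modules) appears in the map;
    # modules not listed are left in floating point.
    for name, module in unet.named_modules():
        n_bits = bitwidth_map.get(name)
        if n_bits is None:
            continue
        # with_args returns a factory; calling it instantiates the observer.
        module.weight_observer = build_weight_observer(n_bits)()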