| """ | |
| Taken from https://huggingface.co/spaces/cbensimon/wan2-1-fast/ | |
| """ | |
| from typing import Any | |
| from typing import Callable | |
| from typing import ParamSpec | |
| import spaces | |
| import torch | |
| from torch.utils._pytree import tree_map_only | |
| from torchao.quantization import quantize_ | |
| from torchao.quantization import Float8DynamicActivationFloat8WeightConfig | |
| from diffusers import LTXConditionPipeline | |
| from optimization_utils import capture_component_call | |
| from optimization_utils import aoti_compile | |

P = ParamSpec("P")

# Sequence packing in LTX is a bit of a pain.
# See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
# TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim("seq_len", min=4680, max=4680)
# Unused currently as I don't know how to make the best use of it for LTX.
# TRANSFORMER_DYNAMIC_SHAPES = {
#     "hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
# }

INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": True,
    "triton.cudagraphs": True,
}

TRANSFORMER_SPATIAL_PATCH_SIZE = 1
TRANSFORMER_TEMPORAL_PATCH_SIZE = 1
VAE_SPATIAL_COMPRESSION_RATIO = 32
VAE_TEMPORAL_COMPRESSION_RATIO = 8


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    num_frames = kwargs.get("num_frames")
    height = kwargs.get("height")
    width = kwargs.get("width")

    latent_num_frames = (num_frames - 1) // VAE_TEMPORAL_COMPRESSION_RATIO + 1
    latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
    latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
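    # For illustration (these values are assumptions, not Space defaults): with height=512,
    # width=704, num_frames=121, this gives latent_num_frames = (121 - 1) // 8 + 1 = 16,
    # latent_height = 512 // 32 = 16, latent_width = 704 // 32 = 22.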

    def compile_transformer():
        # Run the pipeline once and capture the exact args/kwargs the transformer is called with.
        with capture_component_call(pipeline, "transformer") as call:
            pipeline(*args, **kwargs)

        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        # dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Quantize the transformer to float8 dynamic activations / float8 weights before export.
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        hidden_states: torch.Tensor = call.kwargs["hidden_states"]
        unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
            hidden_states,
            latent_num_frames,
            latent_height,
            latent_width,
            TRANSFORMER_SPATIAL_PATCH_SIZE,
            TRANSFORMER_TEMPORAL_PATCH_SIZE,
        )
        unpacked_hidden_states_transposed = unpacked_hidden_states.transpose(-1, -2).contiguous()

        # Prepare both orientations: landscape (width > height) and portrait.
        if unpacked_hidden_states.shape[-1] > unpacked_hidden_states.shape[-2]:
            hidden_states_landscape = unpacked_hidden_states
            hidden_states_portrait = unpacked_hidden_states_transposed
        else:
            hidden_states_landscape = unpacked_hidden_states_transposed
            hidden_states_portrait = unpacked_hidden_states

        hidden_states_landscape = LTXConditionPipeline._pack_latents(
            hidden_states_landscape, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
        )
        hidden_states_portrait = LTXConditionPipeline._pack_latents(
            hidden_states_portrait, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
        )

        exported_landscape = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {"hidden_states": hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        exported_portrait = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
        compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
        # Avoid weights duplication when serializing back to main process
        compiled_portrait.weights = compiled_landscape.weights

        return compiled_landscape, compiled_portrait

    compiled_landscape, compiled_portrait = compile_transformer()

    def combined_transformer(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs["hidden_states"]
        unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
            hidden_states,
            latent_num_frames,
            latent_height,
            latent_width,
            TRANSFORMER_SPATIAL_PATCH_SIZE,
            TRANSFORMER_TEMPORAL_PATCH_SIZE,
        )
        # Dispatch to the graph compiled for this orientation (landscape when width > height).
        if unpacked_hidden_states.shape[-1] > unpacked_hidden_states.shape[-2]:
            return compiled_landscape(*args, **kwargs)
        else:
            return compiled_portrait(*args, **kwargs)

    transformer_config = pipeline.transformer.config
    transformer_dtype = pipeline.transformer.dtype

    # Swap in the compiled dispatcher while preserving the attributes the pipeline reads.
    pipeline.transformer = combined_transformer
    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
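

# Example usage (a minimal sketch, not part of the original Space): the checkpoint ID,
# device, and generation parameters below are assumptions for illustration only. The
# arguments passed here should match the ones used for actual generation, since the
# exported graphs are specialized to these shapes.
if __name__ == "__main__":
    pipeline = LTXConditionPipeline.from_pretrained(
        "Lightricks/LTX-Video-0.9.7-distilled",  # assumed checkpoint, replace as needed
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    optimize_pipeline_(
        pipeline,
        prompt="a placeholder prompt, used only to trace one real transformer call",
        height=512,
        width=704,
        num_frames=121,
        num_inference_steps=8,
    )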