File size: 4,993 Bytes
			
			| d3b1ec0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | """
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import drain_module_parameters
P = ParamSpec('P')
# --- CORRECTED DYNAMIC SHAPING ---
# VAE temporal scale factor is 1, latent_frames = num_frames. Range is [8, 81].
LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
# The transformer has a patch_size of (1, 2, 2), which means the input latent height and width
# are effectively divided by 2. This creates constraints that fail if the symbolic tracer
# assumes odd numbers are possible.
#
# To solve this, we define the dynamic dimension for the *patched* (i.e., post-division) size,
# and then express the input shape as 2 * this dimension. This mathematically guarantees
# to the compiler that the input latent dimensions are always even, satisfying the constraints.
# App range for pixel dimensions: [480, 832]. VAE scale factor is 8.
# Latent dimension range: [480/8, 832/8] = [60, 104].
# Patched latent dimension range: [60/2, 104/2] = [30, 52].
LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
# Now, we define the dynamic shapes for the transformer's `hidden_states` input,
# which has the shape (batch_size, channels, num_frames, height, width).
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: LATENT_FRAMES_DIM,
        3: 2 * LATENT_PATCHED_HEIGHT_DIM, # Guarantees even height
        4: 2 * LATENT_PATCHED_WIDTH_DIM,  # Guarantees even width
    },
}
# --- END OF CORRECTION ---
# Inductor options handed to `aoti_compile` for both exported transformers.
# Every flag is enabled except epilogue fusion, which is explicitly turned off.
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": True,
    "triton.cudagraphs": True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Fuse the LoRA, quantize, and AOT-compile both transformers of ``pipeline`` in place.

    Steps, in order:

    1. Quantize ``pipeline.text_encoder`` with ``Int8WeightOnlyConfig``.
    2. Inside a ``spaces.GPU`` task: load the lightx2v LoRA into both
       transformers, fuse it (scale 3.0 for ``transformer``, 1.0 for
       ``transformer_2``) and unload the adapter weights; run the pipeline once
       to capture a ``transformer`` call; quantize both transformers with
       ``Float8DynamicActivationFloat8WeightConfig``; export both with
       ``TRANSFORMER_DYNAMIC_SHAPES``; compile both with ``INDUCTOR_CONFIGS``.
    3. Replace each transformer's ``forward`` with its compiled counterpart and
       call ``drain_module_parameters`` on it.

    Args:
        pipeline: The pipeline to optimize. It is mutated in place and must
            expose ``transformer``, ``transformer_2`` and ``text_encoder``.
        *args: Positional arguments for the single example pipeline call.
        **kwargs: Keyword arguments for the single example pipeline call.

    Returns:
        None. All effects are applied to ``pipeline``.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # The same LoRA checkpoint is loaded twice, once per transformer.
        lora_repo = "Kijai/WanVideo_comfy"
        lora_weight_name = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"

        pipeline.load_lora_weights(
            lora_repo,
            weight_name=lora_weight_name,
            adapter_name="lightx2v",
        )
        pipeline.load_lora_weights(
            lora_repo,
            weight_name=lora_weight_name,
            adapter_name="lightx2v_2",
            load_into_transformer_2=True,
        )
        pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
        pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
        pipeline.unload_lora_weights()

        # Run the pipeline once to record the args/kwargs the transformer is
        # actually called with; they serve as example inputs for export.
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Every tensor/bool leaf of the captured kwargs is static (None) unless
        # TRANSFORMER_DYNAMIC_SHAPES overrides that entry.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # The call captured from `transformer` is reused as the example input
        # for `transformer_2`.
        # NOTE(review): this assumes both share a forward signature -- confirm.
        transformers = (pipeline.transformer, pipeline.transformer_2)

        # Phase order is deliberate: quantize both, export both, compile both.
        for module in transformers:
            quantize_(module, Float8DynamicActivationFloat8WeightConfig())

        exported_programs = [
            torch.export.export(
                mod=module,
                args=call.args,
                kwargs=call.kwargs,
                dynamic_shapes=dynamic_shapes,
            )
            for module in transformers
        ]
        compiled_1, compiled_2 = [
            aoti_compile(program, INDUCTOR_CONFIGS) for program in exported_programs
        ]
        return compiled_1, compiled_2

    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())

    compiled_transformer_1, compiled_transformer_2 = compile_transformer()

    # Route each transformer through its compiled artifact, then drain the
    # module's parameters.
    # NOTE(review): presumably the eager weights are no longer needed after
    # this point -- confirm in optimization_utils.
    pipeline.transformer.forward = compiled_transformer_1
    drain_module_parameters(pipeline.transformer)
    pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)
 
			
