"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
from fa3 import FlashFusedFluxAttnProcessor3_0
P = ParamSpec('P')
# torch inductor options passed to the AoT compiler via spaces.aoti_compile
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    blocks_A = pipeline.transformer.transformer_blocks
    blocks_B = pipeline.transformer.single_transformer_blocks
    @spaces.GPU(duration=1500)
    def compile_transformer_block_AB():
        # Run the pipeline once per block type so aoti_capture can record
        # the example inputs received by the first block of each kind.
        with spaces.aoti_capture(blocks_A[0]) as call_A:
            pipeline(*args, **kwargs)
        with spaces.aoti_capture(blocks_B[0]) as call_B:
            pipeline(*args, **kwargs)
        exported_A = torch.export.export(
            mod=blocks_A[0],
            args=call_A.args,
            kwargs=call_A.kwargs,
        )
        exported_B = torch.export.export(
            mod=blocks_B[0],
            args=call_B.args,
            kwargs=call_B.kwargs,
        )
        return (
            spaces.aoti_compile(exported_A, INDUCTOR_CONFIGS).archive_file,
            spaces.aoti_compile(exported_B, INDUCTOR_CONFIGS).archive_file,
        )
    # Fuse the QKV projections and switch to the FlashAttention-3 processor
    # before capture, so the compiled graphs include these optimizations.
    pipeline.transformer.fuse_qkv_projections()
    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    archive_file_A, archive_file_B = compile_transformer_block_AB()

    # Every block of a given type shares the same compiled artifact;
    # only its own weights are bound to the compiled model.
    for blocks, archive_file in zip((blocks_A, blocks_B), (archive_file_A, archive_file_B)):
        for block in blocks:
            weights = ZeroGPUWeights(block.state_dict())
            block.forward = ZeroGPUCompiledModel(archive_file, weights)
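
# Illustrative usage sketch (not part of the original module): how this
# helper would typically be invoked once at startup in a ZeroGPU Space.
# The checkpoint id and warmup arguments are hypothetical placeholders;
# any Flux-style diffusers pipeline exposing `transformer_blocks` and
# `single_transformer_blocks` should work the same way.
if __name__ == '__main__':
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        'black-forest-labs/FLUX.1-dev',  # hypothetical checkpoint choice
        torch_dtype=torch.bfloat16,
    ).to('cuda')

    # Compile once using the exact call signature expected at inference;
    # subsequent pipeline calls then run the AoT-compiled blocks.
    optimize_pipeline_(pipe, prompt='warmup', num_inference_steps=1)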