File size: 1,878 Bytes
3df4fd5
 
 
b63cd34
 
 
 
3df4fd5
 
fc3f0ed
1d06ec0
 
3df4fd5
d00873b
 
dfac6b3
3df4fd5
 
b63cd34
288103a
3df4fd5
3af4a0e
 
 
 
 
 
 
b63cd34
 
 
 
 
 
 
 
288103a
 
b63cd34
288103a
3df4fd5
 
318b03c
b63cd34
 
318b03c
fc3f0ed
3af4a0e
318b03c
3df4fd5
318b03c
1d06ec0
 
318b03c
 
 
 
 
 
 
0dc2e9f
3df4fd5
 
35fad5f
b63cd34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import cudagraph


P = ParamSpec('P')


TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
    'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Swap ``pipeline.transformer`` for an ahead-of-time compiled version.

    ``pipeline(*args, **kwargs)`` is executed once so that the positional and
    keyword arguments the pipeline passes to its ``transformer`` component
    can be captured. The transformer then has its QKV projections fused and
    is quantized with torchao's float8 dynamic-activation/float8-weight
    config. It is exported with ``torch.export`` using the captured inputs,
    and the export is compiled by ``aoti_compile`` with ``INDUCTOR_CONFIGS``.

    The compilation step runs under ``spaces.GPU(duration=1500)``. The
    original transformer's ``config`` attribute is copied onto the compiled
    replacement; presumably other pipeline code reads it -- TODO confirm.

    Args:
        pipeline: The pipeline object; it must be callable and expose a
            ``transformer`` attribute.
        *args, **kwargs: Sample inputs forwarded verbatim to ``pipeline``
            for the capture run.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Record the exact inputs the pipeline feeds to its transformer.
        with capture_component_call(pipeline, 'transformer') as captured:
            pipeline(*args, **kwargs)

        # Tensor and bool leaves become None (static in torch.export); other
        # leaves are left as-is. The transformer-specific dynamic dims then
        # override the matching top-level entries.
        shapes = tree_map_only((torch.Tensor, bool), lambda _: None, captured.kwargs)
        shapes.update(TRANSFORMER_DYNAMIC_SHAPES)

        # Fuse first, then quantize, so quantization applies to the fused layers.
        pipeline.transformer.fuse_qkv_projections()
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        exported_program = torch.export.export(
            mod=pipeline.transformer,
            args=captured.args,
            kwargs=captured.kwargs,
            dynamic_shapes=shapes,
        )
        return aoti_compile(exported_program, INDUCTOR_CONFIGS)

    # Grab the config before the swap; the compiled object does not carry it.
    saved_config = pipeline.transformer.config
    pipeline.transformer = compile_transformer()
    pipeline.transformer.config = saved_config # pyright: ignore[reportAttributeAccessIssue]