import argparse, torch, logging
import packaging.version as pv
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)
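
# Text-to-image inference harness: loads a diffusers pipeline, optionally
# quantizes the unet, compiles it (stable-fast or torch.compile), then runs
# a timed warmup.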


logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('wfx')


# torch >= 1.12 disables TF32 matmuls by default; re-enable for Ampere+ speedups
if pv.parse(torch.__version__) >= pv.parse('1.12.0'):
    torch.backends.cuda.matmul.allow_tf32 = True
    #torch.backends.cudnn.allow_tf32 = True # not sure...
    logger.info('torch >= 1.12 detected, enabling tf32 matmul')


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--disable-xformers', action='store_true', default=False)
    parser.add_argument('--disable-triton', action='store_true', default=False)
    parser.add_argument('--quantize-unet', action='store_true', default=False)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--custom-pipeline', type=str, default=None)
    parser.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
    parser.add_argument('--enable-cuda-graph', action='store_true', default=False)
    parser.add_argument('--disable-prefer-lowp-gemm', action='store_true', default=False)
    return parser.parse_args()

def quantize_unet(m):
    from diffusers.utils import USE_PEFT_BACKEND
    assert USE_PEFT_BACKEND, 'PEFT backend is required to quantize the unet'

    logger.info('PEFT backend detected, quantizing unet...')

    # Dynamic quantization: Linear weights are stored as int8 and dequantized
    # on the fly at inference time.
    m = torch.quantization.quantize_dynamic(
        m, { torch.nn.Linear },
        dtype=torch.qint8,
        inplace=True
    )

    logger.info('unet successfully quantized')
    return m


class WFX():
    def __init__(self) -> None:
        # instance attributes (class-level defaults would be shared state)
        self.compiler_config = CompilationConfig.Default()
        self.T2IPipeline: AutoPipelineForText2Image = None
        self.I2IPipeline: AutoPipelineForImage2Image = None

        # parse CLI args once and reuse them in load()
        self.args = parse_args()
        self._check_optimization(self.args)
    
    def _check_optimization(self, args) -> None:
        logger.info(f'torch version: {torch.__version__}')
        
        if not args.disable_xformers:
            try:
                import xformers
                self.compiler_config.enable_xformers = True
                logger.info(f'xformers version: {xformers.__version__}')
            except ImportError:
                logger.warning('xformers not found, disabling xformers')
        
        if not args.disable_triton:
            try:
                import triton
                self.compiler_config.enable_triton = True
                logger.info(f'triton version: {triton.__version__}')
            except ImportError:
                logger.warning('triton not found, disabling triton')
        
        self.compiler_config.enable_cuda_graph = args.enable_cuda_graph
        
        if args.disable_prefer_lowp_gemm:
            self.compiler_config.prefer_lowp_gemm = False
        
        for key, value in vars(self.compiler_config).items():
            logger.info(f'cc - {key}: {value}')

    def load(self) -> None:
        args = self.args
        extra_kwargs = {
            'torch_dtype': torch.float16,
            'use_safetensors': True,
            'requires_safety_checker': False,
        }
        
        if args.custom_pipeline is not None:
            logger.info(f'loading custom pipeline from "{args.custom_pipeline}"')
            extra_kwargs['custom_pipeline'] = args.custom_pipeline
        
        self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(args.model, **extra_kwargs)
        self.T2IPipeline.safety_checker = None
        self.T2IPipeline.to(torch.device('cuda:0'))
        
        if args.quantize_unet:
            self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
        
        logger.info(f'compiling pipeline in {args.compile_mode} mode...')
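        # 'sfast' compiles the whole pipeline with stable-fast; 'torch' applies
        # torch.compile to the unet and vae; 'no-compile' leaves it uncompiled.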
        if args.compile_mode == 'sfast':
            self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
        elif args.compile_mode == 'torch':
            logger.info('compiling unet...')
            self.T2IPipeline.unet = torch.compile(self.T2IPipeline.unet, mode='max-autotune')
            logger.info('compiling vae...')
            self.T2IPipeline.vae = torch.compile(self.T2IPipeline.vae, mode='max-autotune')
        
        self.warmup()
    
    def warmup(self) -> None:
        warmed_total = 5

        warmup_kwargs = dict(
            prompt='a photo of a cat',
            height=768,
            width=512,
            num_inference_steps=30,
            generator=torch.Generator(device='cuda:0').manual_seed(0),
        )

        logger.info('warming up T2I pipeline...')
        for warmed in range(1, warmed_total + 1):
            # CUDA events time the on-device work of each warmup run
            begin = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            begin.record()
            self.T2IPipeline(**warmup_kwargs)
            end.record()

            # elapsed_time() is only valid once both events have completed
            torch.cuda.synchronize()
            elapsed_time = begin.elapsed_time(end)

            logger.info(f'warmed {warmed}/{warmed_total} - {elapsed_time:.2f}ms')


if __name__ == '__main__':
    wfx = WFX()
    wfx.load()
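
# Example invocation (assumes this file is saved as wfx.py; the model id is
# illustrative, any text-to-image diffusers checkpoint on the Hub should work):
#
#   python wfx.py --model stabilityai/stable-diffusion-2-1 \
#       --compile-mode sfast --enable-cuda-graph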