"""Load a diffusers text-to-image pipeline, optionally quantize and compile it, then warm it up.

Run as a script: command-line flags select the model, the compile mode
('sfast', 'torch' or 'no-compile') and which stable-fast optimizations to
enable.
"""
import argparse
import logging
from typing import Optional, Sequence

import packaging.version as pv
import torch
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
# NOTE: `compile` shadows the builtin of the same name inside this module; the
# name is kept as-is so existing references to it continue to work.
from sfast.compilers.diffusion_pipeline_compiler import (
    CompilationConfig,
    compile,
)

logging.basicConfig(
    level=logging.INFO,
    format='%(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger('wfx')

# Single place that names the device the pipeline and the warmup generator use.
_DEVICE = torch.device('cuda:0')

# Enable TF32 for CUDA matmuls on torch >= 1.12.0.
if pv.parse(torch.__version__) >= pv.parse('1.12.0'):
    torch.backends.cuda.matmul.allow_tf32 = True
    # NOTE: torch.backends.cudnn.allow_tf32 is deliberately left untouched;
    # the original author was not sure it should be forced on.
    logger.info('matching torch version, enabling tf32')


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The parsed namespace with attributes ``disable_xformers``,
        ``disable_triton``, ``quantize_unet``, ``model``, ``custom_pipeline``,
        ``compile_mode``, ``enable_cuda_graph`` and ``disable_prefer_lowp_gemm``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--disable-xformers', action='store_true', default=False)
    parser.add_argument('--disable-triton', action='store_true', default=False)
    parser.add_argument('--quantize-unet', action='store_true', default=False)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--custom-pipeline', type=str, default=None)
    parser.add_argument(
        '--compile-mode',
        default='sfast',
        type=str,
        choices=['sfast', 'torch', 'no-compile'],
    )
    parser.add_argument('--enable-cuda-graph', action='store_true', default=False)
    parser.add_argument('--disable-prefer-lowp-gemm', action='store_true', default=False)
    return parser.parse_args(argv)


def quantize_unet(m):
    """Dynamically quantize the ``torch.nn.Linear`` layers of ``m`` to qint8, in place.

    Args:
        m: The module to quantize (the pipeline's unet at the call site).

    Returns:
        The module returned by ``torch.quantization.quantize_dynamic``.

    Raises:
        RuntimeError: If diffusers reports ``USE_PEFT_BACKEND`` as falsy.
    """
    # Imported lazily so the dependency is only touched when quantization is requested.
    from diffusers.utils import USE_PEFT_BACKEND

    # An explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this precondition.
    if not USE_PEFT_BACKEND:
        raise RuntimeError(
            'quantize_unet requires the diffusers PEFT backend '
            '(diffusers.utils.USE_PEFT_BACKEND is False)'
        )

    logger.info('PEFT backend detected, quantizing unet...')
    m = torch.quantization.quantize_dynamic(
        m,
        {torch.nn.Linear},
        dtype=torch.qint8,
        inplace=True,
    )
    logger.info('unet successfully quantized')
    return m


class WFX():
    """Owns the text-to-image pipeline and the stable-fast compilation config."""

    # Class-level default kept so `WFX.compiler_config` still resolves; every
    # instance rebinds its own config in __init__, so this object is never mutated.
    compiler_config: CompilationConfig.Default = CompilationConfig.Default()
    # Set by load(); None until then.
    T2IPipeline: Optional[AutoPipelineForText2Image] = None
    # Declared but not assigned anywhere in this file.
    I2IPipeline: Optional[AutoPipelineForImage2Image] = None

    def __init__(self) -> None:
        """Parse the command line once and derive the compiler config from it."""
        # Parsed once and reused by load() instead of re-parsing sys.argv there.
        self.args = parse_args()
        # Per-instance config: _check_optimization mutates it, and mutating the
        # class-level object would leak settings across instances.
        self.compiler_config = CompilationConfig.Default()
        self._check_optimization(self.args)

    def _check_optimization(self, args) -> None:
        """Fill ``self.compiler_config`` from ``args`` and log the result.

        xformers and triton are enabled only when not disabled by flag and
        when the corresponding package can be imported.

        Args:
            args: Parsed namespace as returned by ``parse_args``.
        """
        logger.info(f'torch version: {torch.__version__}')

        if not args.disable_xformers:
            try:
                import xformers
                self.compiler_config.enable_xformers = True
                logger.info(f'xformers version: {xformers.__version__}')
            except ImportError:
                logger.warning('xformers not found, disabling xformers')

        if not args.disable_triton:
            try:
                import triton
                self.compiler_config.enable_triton = True
                logger.info(f'triton version: {triton.__version__}')
            except ImportError:
                logger.warning('triton not found, disabling triton')

        self.compiler_config.enable_cuda_graph = args.enable_cuda_graph

        # Only ever turned off here; otherwise the config's own default stands.
        if args.disable_prefer_lowp_gemm:
            self.compiler_config.prefer_lowp_gemm = False

        for key, value in vars(self.compiler_config).items():
            logger.info(f'cc - {key}: {value}')

    def load(self) -> None:
        """Load the text-to-image pipeline, move it to CUDA, optionally quantize/compile, then warm up."""
        args = self.args

        extra_kwargs = {
            'torch_dtype': torch.float16,
            'use_safetensors': True,
            'requires_safety_checker': False,
        }
        if args.custom_pipeline is not None:
            logger.info(f'loading custom pipeline from "{args.custom_pipeline}"')
            extra_kwargs['custom_pipeline'] = args.custom_pipeline

        self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(args.model, **extra_kwargs)
        self.T2IPipeline.safety_checker = None
        self.T2IPipeline.to(_DEVICE)

        if args.quantize_unet:
            self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)

        logger.info(f'compiling pipeline in {args.compile_mode} mode...')
        if args.compile_mode == 'sfast':
            self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
        elif args.compile_mode == 'torch':
            logger.info('compiling unet...')
            self.T2IPipeline.unet = torch.compile(self.T2IPipeline.unet, mode='max-autotune')
            logger.info('compiling vae...')
            self.T2IPipeline.vae = torch.compile(self.T2IPipeline.vae, mode='max-autotune')
        # 'no-compile' falls through: the pipeline is used as loaded.

        self.warmup()

    def warmup(self, rounds: int = 5) -> None:
        """Run the text-to-image pipeline ``rounds`` times and log each run's duration.

        Args:
            rounds: Number of warmup generations. Defaults to 5, the value
                that used to be hard-coded. Values <= 0 do nothing.
        """
        if rounds <= 0:
            return

        warmup_kwargs = dict(
            prompt='a photo of a cat',
            height=768,
            width=512,
            num_inference_steps=30,
            generator=torch.Generator(device=_DEVICE).manual_seed(0),
        )

        logger.info('warming up T2I pipeline...')
        for warmed in range(1, rounds + 1):
            begin = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            begin.record()
            self.T2IPipeline(**warmup_kwargs)
            end.record()
            # Both events must have completed before elapsed_time can be read.
            torch.cuda.synchronize()

            elapsed_time = begin.elapsed_time(end)
            logger.info(f'warmed {warmed}/{rounds} - {elapsed_time:.2f}ms')


if __name__ == '__main__':
    wfx = WFX()
    wfx.load()