wfx / wfx.py
zwv9's picture
args
fdc9e50
import argparse, torch, logging
import packaging.version as pv
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)
# Root logging: INFO and above, rendered as "<logger name> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('wfx')

# Allow TF32 for CUDA matmuls, but only on torch 1.12.0 or newer.
# NOTE(review): presumably gated on 1.12 because of a default change in that
# release -- confirm against the PyTorch release notes.
if pv.parse('1.12.0') <= pv.parse(torch.__version__):
    torch.backends.cuda.matmul.allow_tf32 = True
    # The cuDNN counterpart (torch.backends.cudnn.allow_tf32) is intentionally
    # left untouched; the original author was not sure about enabling it.
    logger.info('matching torch version, enabling tf32')
def parse_args(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, which is what every existing
            ``parse_args()`` call relies on.

    Returns:
        argparse.Namespace with the attributes ``disable_xformers``,
        ``disable_triton``, ``quantize_unet``, ``model``,
        ``custom_pipeline``, ``compile_mode``, ``enable_cuda_graph`` and
        ``disable_prefer_lowp_gemm``.

    Raises:
        SystemExit: raised by argparse when ``--model`` is missing or an
            option value is invalid.
    """
    parser = argparse.ArgumentParser()

    # What to load.
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--custom-pipeline', type=str, default=None)

    # How to compile it.
    parser.add_argument(
        '--compile-mode',
        default='sfast',
        type=str,
        choices=['sfast', 'torch', 'no-compile'],
    )

    # Boolean switches; ``store_true`` already defaults to False.
    for flag in (
        '--disable-xformers',
        '--disable-triton',
        '--quantize-unet',
        '--enable-cuda-graph',
        '--disable-prefer-lowp-gemm',
    ):
        parser.add_argument(flag, action='store_true')

    return parser.parse_args(argv)
def quantize_unet(m):
    """Dynamically quantize the ``torch.nn.Linear`` layers of ``m`` to qint8.

    The quantization is requested in place (``inplace=True``) and the
    object returned by ``torch.quantization.quantize_dynamic`` is handed
    back to the caller.

    Args:
        m: The module to quantize (the caller passes the pipeline's unet).

    Returns:
        The quantized module.

    Raises:
        AssertionError: if diffusers reports that the PEFT backend is not
            in use (``USE_PEFT_BACKEND`` is falsy).
    """
    from diffusers.utils import USE_PEFT_BACKEND

    # Explicit check instead of a bare ``assert``: assert statements are
    # removed under ``python -O``, which would silently skip this guard.
    # AssertionError is kept so existing handlers keep working.
    if not USE_PEFT_BACKEND:
        raise AssertionError(
            'quantize_unet requires the diffusers PEFT backend '
            '(diffusers.utils.USE_PEFT_BACKEND is false)'
        )

    logger.info('PEFT backend detected, quantizing unet...')
    m = torch.quantization.quantize_dynamic(
        m,
        {torch.nn.Linear},
        dtype=torch.qint8,
        inplace=True,
    )
    logger.info('unet successfully quantized')
    return m
class WFX:
    """Load a diffusers text-to-image pipeline, optionally compile it, and warm it up.

    Typical use is ``WFX().load()``: the constructor parses the command line
    and derives the compiler configuration, ``load`` builds the pipeline on
    ``cuda:0``, applies the selected compile mode and runs the warm-up.
    """

    # Compiler settings handed to ``compile`` in 'sfast' mode. Created per
    # instance in __init__ because _check_optimization mutates it; a single
    # class-level object would leak those mutations across instances.
    compiler_config: CompilationConfig.Default
    # Text-to-image pipeline; populated by load().
    T2IPipeline: AutoPipelineForText2Image = None
    # Image-to-image pipeline; declared here but never assigned by this class.
    I2IPipeline: AutoPipelineForImage2Image = None

    def __init__(self) -> None:
        """Parse the command line once and configure the compiler from it."""
        self.compiler_config = CompilationConfig.Default()
        # Kept so load() reuses the same namespace instead of re-parsing.
        self._args = parse_args()
        self._check_optimization(self._args)

    def _check_optimization(self, args) -> None:
        """Fill ``self.compiler_config`` from the CLI flags and installed packages.

        Args:
            args: Parsed CLI namespace; reads ``disable_xformers``,
                ``disable_triton``, ``enable_cuda_graph`` and
                ``disable_prefer_lowp_gemm``.
        """
        config = self.compiler_config
        logger.info(f'torch version: {torch.__version__}')

        # xformers: on only when not disabled by flag AND importable. The flag
        # is written in every case so that a "disabling" log line is truthful.
        use_xformers = False
        if not args.disable_xformers:
            try:
                import xformers
            except ImportError:
                logger.warning('xformers not found, disabling xformers')
            else:
                use_xformers = True
                logger.info(f'xformers version: {xformers.__version__}')
        config.enable_xformers = use_xformers

        # triton: same policy as xformers.
        use_triton = False
        if not args.disable_triton:
            try:
                import triton
            except ImportError:
                logger.warning('triton not found, disabling triton')
            else:
                use_triton = True
                logger.info(f'triton version: {triton.__version__}')
        config.enable_triton = use_triton

        config.enable_cuda_graph = args.enable_cuda_graph
        # Only override prefer_lowp_gemm when asked to; otherwise the value
        # the config was constructed with is kept.
        if args.disable_prefer_lowp_gemm:
            config.prefer_lowp_gemm = False

        # Dump the effective configuration for troubleshooting.
        for key, value in vars(config).items():
            logger.info(f'cc - {key}: {value}')

    def load(self) -> None:
        """Build the text-to-image pipeline, optimize it, and warm it up.

        Loads ``args.model`` in float16 from safetensors with the safety
        checker removed, moves it to ``cuda:0``, optionally quantizes the
        unet, applies the chosen compile mode ('sfast', 'torch' or
        'no-compile') and finally calls ``warmup``.
        """
        args = self._args
        extra_kwargs = {
            'torch_dtype': torch.float16,
            'use_safetensors': True,
            'requires_safety_checker': False,
        }
        if args.custom_pipeline is not None:
            logger.info(f'loading custom pipeline from "{args.custom_pipeline}"')
            extra_kwargs['custom_pipeline'] = args.custom_pipeline

        self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(
            args.model, **extra_kwargs
        )
        self.T2IPipeline.safety_checker = None
        self.T2IPipeline.to(torch.device('cuda:0'))

        if args.quantize_unet:
            self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)

        logger.info(f'compiling pipeline in {args.compile_mode} mode...')
        if args.compile_mode == 'sfast':
            # ``compile`` is the sfast pipeline compiler imported at the top
            # of the file, not the Python builtin.
            self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
        elif args.compile_mode == 'torch':
            logger.info('compiling unet...')
            self.T2IPipeline.unet = torch.compile(
                self.T2IPipeline.unet, mode='max-autotune'
            )
            logger.info('compiling vae...')
            self.T2IPipeline.vae = torch.compile(
                self.T2IPipeline.vae, mode='max-autotune'
            )
        # 'no-compile': the pipeline is used as loaded.

        self.warmup()

    def warmup(self, runs: int = 5) -> None:
        """Run the text-to-image pipeline ``runs`` times and log each duration.

        Args:
            runs: Number of warm-up generations; defaults to 5. Nothing is
                run or logged when it is zero or negative.
        """
        if runs <= 0:
            return

        warmup_kwargs = dict(
            prompt='a photo of a cat',
            height=768,
            width=512,
            num_inference_steps=30,
            # One seeded generator shared by all runs.
            generator=torch.Generator(device='cuda:0').manual_seed(0),
        )

        logger.info('warming up T2I pipeline...')
        for done in range(1, runs + 1):
            begin = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            begin.record()
            self.T2IPipeline(**warmup_kwargs)
            end.record()
            # Both events must have completed before the elapsed time is read.
            torch.cuda.synchronize()
            elapsed_time = begin.elapsed_time(end)
            logger.info(f'warmed {done}/{runs} - {elapsed_time:.2f}ms')
if __name__ == '__main__':
    # Script entry point: construct the runner (parses the CLI and prepares
    # the compiler configuration), then load the pipeline.
    WFX().load()