test / cli /model-jit.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
#!/usr/bin/env python
import os
import time
import functools
import argparse
import logging
import warnings
from dataclasses import dataclass
logging.getLogger("DeepSpeed").disabled = True
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
import torch
import diffusers
# Pass counts and shared module state used by the functions below.
n_warmup = 5 # eager unet passes run by trace_model() before tracing
n_traces = 10 # passes through the traced unet run by trace_model() after tracing
n_runs = 100 # unet passes executed by each benchmark_model() call
args = {} # placeholder; replaced by the parsed argparse namespace in the __main__ block
pipe = None # diffusers pipeline; assigned by load_model()
log = logging.getLogger("sd") # module logger; its handler is attached in setup_logging()
def setup_logging():
    """Route the module logger ``log`` through a rich console handler.

    Sets ``log`` to DEBUG, points the root logger at a NullHandler, lowers the
    "diffusers" and "torch" loggers to ERROR, ignores ``torch.jit.TracerWarning``
    and installs rich tracebacks on the same console.
    """
    from rich.console import Console
    from rich.logging import RichHandler
    from rich.theme import Theme
    from rich.traceback import install as install_traceback

    border_theme = Theme({ "traceback.border": "black", "traceback.border.syntax_error": "black", "inspect.value.border": "black" })
    rich_console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=border_theme)

    log.setLevel(logging.DEBUG)
    # redirect default logger to null
    logging.basicConfig(
        level=logging.ERROR,
        format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
        handlers=[logging.NullHandler()],
    )

    handler = RichHandler(
        show_time=True,
        omit_repeated_times=False,
        show_level=True,
        show_path=False,
        markup=False,
        rich_tracebacks=True,
        log_time_format='%H:%M:%S-%f',
        level=logging.DEBUG,
        console=rich_console,
    )
    handler.setLevel(logging.DEBUG)
    log.addHandler(handler)

    # keep third-party loggers and tracer warnings out of the output
    for noisy_logger in ("diffusers", "torch"):
        logging.getLogger(noisy_logger).setLevel(logging.ERROR)
    warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)

    install_traceback(
        console=rich_console,
        extra_lines=1,
        max_frames=10,
        width=rich_console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=[],
    )
def generate_inputs():
    """Create random half-precision CUDA tensors to feed the unet.

    Returns:
        ``(sample, timestep, encoder_hidden_states)`` when ``args.type`` is
        'sd15'; the same three tensors plus ``text_embeds`` when it is 'sdxl';
        ``None`` for any other value.
    """
    model_type = args.type
    if model_type not in ('sd15', 'sdxl'):
        return None

    # both model types share the first three tensors, created in this order
    sample = torch.randn(2, 4, 64, 64).half().cuda()
    timestep = torch.rand(1).half().cuda() * 999
    encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
    if model_type == 'sd15':
        return sample, timestep, encoder_hidden_states

    # NOTE(review): the sdxl branch reuses the sd15 shapes and passes text_embeds
    # as a fourth positional argument - confirm this matches the SDXL unet signature
    text_embeds = torch.randn(1, 77, 2048).half().cuda()
    return sample, timestep, encoder_hidden_states, text_embeds
def load_model():
    """Load the single-file checkpoint at ``args.model`` onto CUDA and store it in ``pipe``.

    The pipeline class is ``StableDiffusionPipeline`` when ``args.type`` is
    'sd15' and ``StableDiffusionXLPipeline`` otherwise. Load time and file size
    are logged.
    """
    global pipe # pylint: disable=global-statement
    log.info(f'versions: torch={torch.__version__} diffusers={diffusers.__version__}')

    load_options = {
        "low_cpu_mem_usage": True,
        "torch_dtype": torch.float16,
        "safety_checker": None,
        "requires_safety_checker": False,
        "load_safety_checker": False,
        "load_connected_pipeline": True,
        "use_safetensors": True,
    }
    if args.type == 'sd15':
        pipeline_cls = diffusers.StableDiffusionPipeline
    else:
        pipeline_cls = diffusers.StableDiffusionXLPipeline

    started = time.time()
    loaded = pipeline_cls.from_single_file(args.model, **load_options)
    pipe = loaded.to('cuda')
    n_bytes = os.path.getsize(args.model)
    log.info(f'load: model={args.model} type={args.type} time={time.time() - started:.3f}s size={n_bytes / 1024 / 1024:.3f}mb')
def load_trace(fn: str):
    """Load a saved TorchScript unet from ``fn`` and install it as ``pipe.unet``.

    The loaded module is wrapped in a small ``torch.nn.Module`` that copies
    ``in_channels`` and ``device`` from the current ``pipe.unet`` and returns
    the first element of the traced output in an object with a ``sample`` field.

    Args:
        fn: path of the file passed to ``torch.jit.load``.
    """
    @dataclass
    class UNet2DConditionOutput:
        sample: torch.FloatTensor

    class TracedUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # copied from the unet that is about to be replaced
            current_unet = pipe.unet
            self.in_channels = current_unet.in_channels
            self.device = current_unet.device

        def forward(self, latent_model_input, t, encoder_hidden_states):
            # `traced` is bound in the enclosing scope before the first call
            outputs = traced(latent_model_input, t, encoder_hidden_states)
            return UNet2DConditionOutput(sample=outputs[0])

    started = time.time()
    traced = torch.jit.load(fn)
    wrapper = TracedUNet()
    pipe.unet = wrapper
    n_bytes = os.path.getsize(fn)
    log.info(f'load: optimized={fn} time={time.time() - started:.3f}s size={n_bytes / 1024 / 1024:.3f}mb')
def trace_model():
    """Trace ``pipe.unet`` with ``torch.jit.trace`` and optionally save the result.

    Steps: disable grad globally, put the unet in eval mode, default its
    ``forward`` to ``return_dict=False``, run ``n_warmup`` eager passes, trace
    with the last warmup inputs, then run ``n_traces`` passes through the traced
    module.

    Returns:
        The path of the saved ``.pt`` file when ``args.save`` is set (``pipe.unet``
        is left untouched in that case); otherwise ``None`` after ``pipe.unet``
        has been replaced by the traced module.
    """
    log.info(f'tracing model: {args.model}')
    torch.set_grad_enabled(False)
    unet = pipe.unet
    unet.eval()
    # channels_last memory format is intentionally left disabled:
    # unet.to(memory_format=torch.channels_last)
    unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default

    # warmup passes on the eager module
    started = time.time()
    with torch.inference_mode():
        for _ in range(n_warmup):
            inputs = generate_inputs()
            unet(*inputs)
    log.info(f'warmup: time={time.time() - started:.3f}s passes={n_warmup}')

    # trace using the inputs from the last warmup pass
    started = time.time()
    unet_traced = torch.jit.trace(unet, inputs, check_trace=True)
    unet_traced.eval()
    log.info(f'trace: time={time.time() - started:.3f}s')

    # run the traced module a few times so the graph gets optimized
    started = time.time()
    with torch.inference_mode():
        for _ in range(n_traces):
            inputs = generate_inputs()
            unet_traced(*inputs)
    log.info(f'optimize: time={time.time() - started:.3f}s passes={n_traces}')

    if not args.save:
        pipe.unet = unet_traced
        return None

    # save the traced module next to the model file, with a .pt extension
    started = time.time()
    stem = os.path.splitext(args.model)[0]
    fn = f"{stem}.pt"
    unet_traced.save(fn)
    n_bytes = os.path.getsize(fn)
    log.info(f'save: optimized={fn} time={time.time() - started:.3f}s size={n_bytes / 1024 / 1024:.3f}mb')
    return fn
def benchmark_model(msg: str):
    """Time repeated ``pipe.unet`` passes on one fixed set of random inputs.

    ``n_runs`` passes are executed in total. The first ``n_runs // 10`` are
    untimed warmup; the timer is started once, after a CUDA synchronize, and
    stopped after a second synchronize that follows the remaining passes.

    The previous implementation reassigned the start time on every iteration
    past the warmup threshold, so the reported duration only covered the final
    pass, and the start time was unbound for very small ``n_runs``.

    Args:
        msg: label written to the log line (e.g. 'original', 'traced').

    Returns:
        Elapsed seconds for the timed (non-warmup) passes.
    """
    warmup_passes = n_runs // 10
    timed_passes = n_runs - warmup_passes
    with torch.inference_mode():
        inputs = generate_inputs()
        for _ in range(warmup_passes):
            pipe.unet(*inputs)
        # CUDA kernels are queued asynchronously: drain the queue before starting the clock
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(timed_passes):
            pipe.unet(*inputs)
        # wait for all timed work to finish before stopping the clock
        torch.cuda.synchronize()
        t1 = time.perf_counter()
    log.info(f"benchmark unet: {t1 - t0:.3f}s passes={n_runs} type={msg}")
    return t1 - t0
if __name__ == '__main__':
    # Command-line entry point: load the model, optionally benchmark it,
    # trace the unet (optionally saving and reloading it) and benchmark again.
    parser = argparse.ArgumentParser(description = 'SD.Next')
    parser.add_argument('--model', type=str, default='', required=True, help='model path')
    parser.add_argument('--type', type=str, default='sd15', choices=['sd15', 'sdxl'], required=False, help='model type, default: %(default)s')
    parser.add_argument('--benchmark', default = False, action='store_true', help = "run benchmarks, default: %(default)s")
    parser.add_argument('--trace', default = True, action='store_true', help = "run jit tracing, default: %(default)s")
    # --trace defaults to True and is store_true, so on its own it can never be
    # switched off; --no-trace writes False to the same destination
    parser.add_argument('--no-trace', dest='trace', action='store_false', help = "skip jit tracing")
    parser.add_argument('--save', default = False, action='store_true', help = "save optimized unet, default: %(default)s")
    args = parser.parse_args()
    setup_logging()
    log.info('sdnext model jit tracing')
    if not os.path.isfile(args.model):
        log.error(f"invalid model path: {args.model}")
        # SystemExit instead of exit(): exit() is injected by the site module
        # and is not guaranteed to exist in every interpreter setup
        raise SystemExit(1)
    load_model()
    time0 = benchmark_model('original') if args.benchmark else None
    if args.trace:
        unet_saved = trace_model()
        if unet_saved is not None:
            # a path is returned only when the traced unet was saved; reload it from disk
            load_trace(unet_saved)
        if args.benchmark:
            time1 = benchmark_model('traced')
            log.info(f'benchmark speedup: {100 * (time0 - time1) / time0:.3f}%')