| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import types |
| | from pathlib import Path |
| |
|
| | import tensorrt as trt |
| | import torch |
| | from cache_diffusion.cachify import CACHED_PIPE, get_model |
| | from cuda import cudart |
| | from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
| | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
| | from trt_pipeline.config import ONNX_CONFIG |
| | from trt_pipeline.models.sd3 import sd3_forward |
| | from trt_pipeline.models.sdxl import ( |
| | cachecrossattnupblock2d_forward, |
| | cacheunet_forward, |
| | cacheupblock2d_forward, |
| | ) |
| | from polygraphy.backend.trt import ( |
| | CreateConfig, |
| | Profile, |
| | engine_from_network, |
| | network_from_onnx_path, |
| | save_engine, |
| | ) |
| | from torch.onnx import export as onnx_export |
| |
|
| | from .utils import Engine |
| |
|
| |
|
def replace_new_forward(backbone):
    """Install cache-aware forward methods on a supported backbone in place.

    A ``UNet2DConditionModel`` gets the cached UNet forward, and each of its
    up-blocks gets a cached forward (cross-attention up-blocks receive their
    own variant). An ``SD3Transformer2DModel`` gets the cached SD3 forward.
    Any other class is left untouched.
    """
    backbone_cls = backbone.__class__
    if backbone_cls == SD3Transformer2DModel:
        backbone.forward = types.MethodType(sd3_forward, backbone)
        return
    if backbone_cls != UNet2DConditionModel:
        return
    backbone.forward = types.MethodType(cacheunet_forward, backbone)
    for up_block in backbone.up_blocks:
        # Missing attribute is treated the same as a falsy one.
        if getattr(up_block, "has_cross_attention", False):
            new_forward = cachecrossattnupblock2d_forward
        else:
            new_forward = cacheupblock2d_forward
        up_block.forward = types.MethodType(new_forward, up_block)
| |
|
| |
|
def get_input_info(dummy_dict, info=None, batch_size=1):
    """Flatten a nested dummy-input config into the form a TensorRT helper needs.

    Args:
        dummy_dict: Arbitrarily nested dict whose leaf values are shape
            tuples; the leading (batch) dimension of each leaf is scaled by
            ``batch_size``.
        info: Selects the output form: ``"profile_shapes"`` -> list of
            ``(name, shape)`` pairs; ``"profile_shapes_dict"`` -> dict of
            name -> shape; ``"dummy_input"`` -> dict of name -> half-precision
            CUDA tensor of ones; ``"input_names"`` -> list of leaf names.
            Any other value (including the default ``None``) returns an
            empty dict.
        batch_size: Multiplier applied to each leaf's first dimension.

    Returns:
        A list for ``"profile_shapes"``/``"input_names"``, otherwise a dict.
    """
    # List-shaped outputs vs. dict-shaped outputs.
    return_val = [] if info in ("profile_shapes", "input_names") else {}

    def collect_leaf_keys(d):
        # Depth-first walk: nested dicts recurse, shape-tuple leaves are emitted.
        for key, value in d.items():
            if isinstance(value, dict):
                collect_leaf_keys(value)
                continue
            value = (value[0] * batch_size,) + value[1:]
            if info == "profile_shapes":
                return_val.append((key, value))
            elif info == "profile_shapes_dict":
                return_val[key] = value
            elif info == "dummy_input":
                return_val[key] = torch.ones(value).half().cuda()
            elif info == "input_names":
                return_val.append(key)

    collect_leaf_keys(dummy_dict)
    return return_val
| |
|
| |
|
def get_total_device_memory(backbone):
    """Return the largest per-engine device-memory requirement, in bytes.

    The engines later share a single device allocation, so the allocation
    only needs to cover the maximum individual requirement, not the sum.

    Args:
        backbone: Object whose ``engines`` dict maps block names to loaded
            engine wrappers exposing ``engine.device_memory_size``.

    Returns:
        The maximum ``device_memory_size`` over all engines, or 0 when the
        backbone has no engines.
    """
    # Original iterated .items() and discarded the keys; only values matter.
    return max(
        (engine.engine.device_memory_size for engine in backbone.engines.values()),
        default=0,
    )
| |
|
| |
|
def load_engines(backbone, engine_path: Path, batch_size: int = 1):
    """Load every serialized TensorRT engine under ``engine_path`` onto ``backbone``.

    Deserializes each regular file in the directory into ``backbone.engines``
    (keyed by file stem), activates all engines against one shared device
    allocation sized to the largest single requirement, creates one CUDA
    stream, and allocates per-engine I/O buffers from the ONNX_CONFIG profile
    shapes scaled by ``batch_size``.
    """
    backbone.engines = {}
    for entry in engine_path.iterdir():
        if not entry.is_file():
            continue
        loaded = Engine()
        loaded.load(str(entry))
        backbone.engines[entry.stem] = loaded

    # One shared allocation covers the largest single-engine requirement.
    _, shared_device_memory = cudart.cudaMalloc(get_total_device_memory(backbone))
    for engine in backbone.engines.values():
        engine.activate(shared_device_memory)

    backbone.cuda_stream = cudart.cudaStreamCreate()[1]
    for block_name, engine in backbone.engines.items():
        profile_shapes = get_input_info(
            ONNX_CONFIG[backbone.__class__][block_name]["dummy_input"],
            "profile_shapes_dict",
            batch_size,
        )
        engine.allocate_buffers(
            shape_dict=profile_shapes,
            device=backbone.device,
            batch_size=batch_size,
        )
| | |
| |
|
| |
|
def warm_up(backbone, batch_size: int = 1):
    """Run one dummy inference through each engine to prime it before real use."""
    print("Warming-up TensorRT engines...")
    for block_name, engine in backbone.engines.items():
        feed = get_input_info(
            ONNX_CONFIG[backbone.__class__][block_name]["dummy_input"],
            "dummy_input",
            batch_size,
        )
        engine(feed, backbone.cuda_stream)
| |
|
| |
|
def teardown(pipe):
    """Release the TensorRT engines and CUDA stream held by the pipeline's backbone.

    Args:
        pipe: Diffusion pipeline; its backbone is resolved via ``get_model``.
    """
    backbone = get_model(pipe)
    # Bug fix: the original `for engine in ...: del engine` only unbound the
    # loop variable — backbone.engines still referenced every engine, so
    # nothing was released. Clearing the dict drops those references so the
    # engines can actually be garbage-collected (presumably freeing their
    # device resources — TODO confirm Engine's cleanup behavior).
    backbone.engines.clear()

    cudart.cudaStreamDestroy(backbone.cuda_stream)
    del backbone.cuda_stream
| |
|
| |
|
def load_unet_trt(unet, engine_path: Path, batch_size: int = 1):
    """Wire a backbone to its prebuilt TensorRT engines for inference.

    Ensures the engine directory exists, installs the cache-aware forward
    methods, loads/activates every engine found under ``engine_path``, runs a
    warm-up pass, and marks the backbone as TensorRT-enabled.
    """
    engine_path.mkdir(parents=True, exist_ok=True)
    replace_new_forward(unet)
    load_engines(unet, engine_path, batch_size)
    warm_up(unet, batch_size)
    unet.use_trt_infer = True
| |
|