# MyCustomNodes / tensorrt_loader.py
# Author: saliacoel — commit e24ff1f ("Upload 2 files")
import torch
import os
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.supported_models
import folder_paths
# Make "<models_dir>/tensorrt" a search location for ".engine" files. Another
# extension may already have registered a "tensorrt" category, so extend that
# entry in place instead of replacing it.
_tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt")
if "tensorrt" in folder_paths.folder_names_and_paths:
    _tensorrt_entry = folder_paths.folder_names_and_paths["tensorrt"]
    # Skip the append when the directory is already listed (e.g. registered by
    # another extension or by an earlier import of this module) so the same
    # folder is not searched twice.
    if _tensorrt_models_dir not in _tensorrt_entry[0]:
        _tensorrt_entry[0].append(_tensorrt_models_dir)
    _tensorrt_entry[1].add(".engine")
else:
    folder_paths.folder_names_and_paths["tensorrt"] = (
        [_tensorrt_models_dir], {".engine"})
import tensorrt as trt
# Register TensorRT's bundled plugin library under the default ("") namespace.
# This runs once at import time, before any engine below is deserialized.
# NOTE(review): a logger is not passed here (None), so plugin-registration
# messages do not go through `logger` — confirm this is intended.
trt.init_libnvinfer_plugins(None, "")
# Module-wide TensorRT logger at INFO severity; handed to the runtime below.
logger = trt.Logger(trt.Logger.INFO)
# Single shared runtime; TrTUnet uses it to deserialize each engine file.
runtime = trt.Runtime(logger)
def trt_datatype_to_torch(datatype):
    """Map a TensorRT ``DataType`` to the torch dtype with the same element layout.

    Works for TRT 8/9/10: every candidate is looked up by name with ``getattr``
    (both the ``trt.<alias>`` and ``trt.DataType.<MEMBER>`` spellings), so a
    member missing from an older TensorRT release is skipped instead of
    raising ``AttributeError``.

    Args:
        datatype: a ``tensorrt.DataType`` value, e.g. from
            ``engine.get_tensor_dtype(name)``.

    Returns:
        The matching ``torch.dtype``. Types with no entry below (e.g. FP8/INT4)
        fall back to ``torch.float32``, kept for backward compatibility.
    """
    # (module-level alias, DataType member name, torch dtype)
    known = (
        ("float16", "HALF", torch.float16),
        ("float32", "FLOAT", torch.float32),
        ("bfloat16", "BF16", torch.bfloat16),
        ("int32", "INT32", torch.int32),
        # Integer/bool I/O must keep its element size: a float32 fallback gives
        # int64 a buffer half the size TensorRT reads, and int8/bool a buffer
        # with the wrong layout.
        ("int64", "INT64", torch.int64),
        ("int8", "INT8", torch.int8),
        ("uint8", "UINT8", torch.uint8),
        ("bool", "BOOL", torch.bool),
    )
    for alias, member, torch_dtype in known:
        candidates = (getattr(trt, alias, None), getattr(trt.DataType, member, None))
        if any(c is not None and datatype == c for c in candidates):
            return torch_dtype
    # Fallback – shouldn't normally hit this for UNets
    return torch.float32
class TrTUnet:
    """Stand-in for a ComfyUI diffusion UNet that runs a serialized TensorRT engine.

    The loader assigns an instance to ``model.diffusion_model``; it is then
    called with the same arguments a UNet receives.
    """

    def __init__(self, engine_path):
        """Deserialize the TensorRT engine at *engine_path* and create a context.

        Raises:
            RuntimeError: if TensorRT returns None for the engine or for the
                execution context, which is how it reports those failures.
        """
        with open(engine_path, "rb") as f:
            serialized = f.read()
        self.engine = runtime.deserialize_cuda_engine(serialized)
        if self.engine is None:
            raise RuntimeError(f"TensorRT UNet: failed to deserialize engine '{engine_path}'")
        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError(
                f"TensorRT UNet: failed to create an execution context for '{engine_path}'"
            )

        # Default torch device / dtype for allocations
        self.device = comfy.model_management.get_torch_device()
        self.default_dtype = torch.float16  # fallback if something unknown shows up
        # Dtype advertised as ``diffusion_model.dtype``; the loader overrides it
        # for Flux. NOTE(review): assumed to be read by ComfyUI's model wrapper
        # (BaseModel.get_dtype), so every model type needs a default here.
        self.dtype = torch.float16

        # An engine's I/O tensor names never change, so resolve them once
        # instead of on every sampling step.
        names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        self._input_names = [
            n for n in names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT
        ]
        self._output_names = [
            n for n in names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
        ]

    def _trt_dtype_to_torch(self, trt_dtype):
        """Map a TensorRT dtype to torch, using ``default_dtype`` if the mapper yields None."""
        dt = trt_datatype_to_torch(trt_dtype)
        return dt if dt is not None else self.default_dtype

    @staticmethod
    def _gather_inputs(input_names, provided, extra):
        """Pick a value for every engine input name.

        Values in *provided* take precedence over *extra* (the caller's
        ``**kwargs``), and ``None`` counts as absent. Entries whose names the
        engine does not declare are ignored rather than treated as an error.

        Returns:
            dict mapping each name in *input_names* to its value.

        Raises:
            RuntimeError: if any engine input has no value.
        """
        gathered = {}
        for name in input_names:
            value = provided.get(name)
            if value is None:
                value = extra.get(name)
            if value is not None:
                gathered[name] = value
        missing = [n for n in input_names if n not in gathered]
        if missing:
            have = sorted({k for d in (provided, extra) for k, v in d.items() if v is not None})
            raise RuntimeError(
                f"TensorRT UNet: missing required inputs for engine: {missing} "
                f"(have {have})"
            )
        return gathered

    def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs):
        """Run the engine for one denoising step.

        Args:
            x: latent batch (assumed [B, C, H, W]).
            timesteps: per-sample timesteps (assumed [B]).
            context: conditioning (assumed [B, N, D]).
            y: optional extra conditioning (e.g. SDXL, assumed [B, y_dim]); bound
                only if the engine declares an input named ``"y"``.
            control, transformer_options: accepted for UNet call compatibility
                but not used (ControlNet residuals are not fed to the engine).
            **kwargs: supplies any other engine input by tensor name
                (e.g. ``"guidance"`` for Flux); ``None`` counts as absent.

        Returns:
            The output tensor, or a tuple of outputs in engine order when the
            engine has more than one.

        Raises:
            RuntimeError: an engine input has no value, the engine rejects an
                input shape, output shapes cannot be resolved, or enqueueing fails.
        """
        # 1. Match the caller's tensors to the engine's declared inputs.
        provided = {"x": x, "timesteps": timesteps, "context": context}
        if y is not None:
            provided["y"] = y
        model_inputs = self._gather_inputs(self._input_names, provided, kwargs)

        # 2. Convert each input to the engine's device/dtype and bind it.
        for name in self._input_names:
            torch_dtype = self._trt_dtype_to_torch(self.engine.get_tensor_dtype(name))
            # TensorRT reads a dense row-major buffer from the raw pointer, so
            # strided views (expand/transpose/slices) must be materialized first.
            t = model_inputs[name].to(device=self.device, dtype=torch_dtype).contiguous()
            # Hold the converted tensor so its memory stays valid until the
            # engine has been enqueued below.
            model_inputs[name] = t
            shape = tuple(t.shape)
            if not self.context.set_input_shape(name, shape):
                raise RuntimeError(
                    f"TensorRT UNet: engine rejected shape {shape} for input '{name}' "
                    f"(likely outside the engine's optimization profile)"
                )
            self.context.set_tensor_address(name, int(t.data_ptr()))

        # Make sure all shapes are resolved
        unresolved = self.context.infer_shapes()
        if unresolved:
            raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {unresolved}")

        # 3. Allocate & bind outputs, in engine output order.
        outputs = []
        for name in self._output_names:
            out_shape = tuple(int(d) for d in self.context.get_tensor_shape(name))
            torch_dtype = self._trt_dtype_to_torch(self.engine.get_tensor_dtype(name))
            out_tensor = torch.empty(out_shape, device=self.device, dtype=torch_dtype)
            self.context.set_tensor_address(name, int(out_tensor.data_ptr()))
            outputs.append(out_tensor)

        # 4. Enqueue on PyTorch's current stream. This orders the engine after
        # the kernels that produced the inputs and before later torch work on
        # that stream, so no explicit synchronization is needed here.
        stream = torch.cuda.current_stream(self.device)
        if not self.context.execute_async_v3(stream_handle=stream.cuda_stream):
            raise RuntimeError("TensorRT UNet: execute_async_v3 failed to enqueue the engine")

        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    def load_state_dict(self, sd, strict=False):
        """No-op: the weights live inside the TensorRT engine."""
        pass

    def state_dict(self):
        """Return an empty dict; there are no torch parameters to expose."""
        return {}
class TensorRTLoader:
    """ComfyUI node that loads a TensorRT UNet engine and wraps it as a MODEL."""

    # Supported model families, in the order shown in the node's dropdown.
    # Also used by load_unet to validate model_type before the engine is loaded.
    _MODEL_TYPES = ("sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd",
                    "sd3", "auraflow", "flux_dev", "flux_schnell")

    @classmethod
    def INPUT_TYPES(s):
        """Declare the inputs: an engine file and the model family it was built for."""
        return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
                             "model_type": (list(s._MODEL_TYPES), ),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "TensorRT"

    def load_unet(self, unet_name, model_type):
        """Load the engine *unet_name* and wrap it in a ComfyUI model of *model_type*.

        Returns:
            A 1-tuple holding a ``ModelPatcher`` whose ``diffusion_model`` is the
            TensorRT-backed ``TrTUnet``.

        Raises:
            ValueError: if *model_type* is not one of ``_MODEL_TYPES``.
            FileNotFoundError: if the engine file cannot be found.
        """
        # Validate first: it is cheap, while deserializing an engine is not.
        if model_type not in self._MODEL_TYPES:
            raise ValueError(
                f"TensorRTLoader: unsupported model_type {model_type!r} "
                f"(expected one of {list(self._MODEL_TYPES)})"
            )
        unet_path = folder_paths.get_full_path("tensorrt", unet_name)
        # get_full_path can return None for an unknown name, and
        # os.path.isfile(None) would raise TypeError rather than a clear error.
        if unet_path is None or not os.path.isfile(unet_path):
            raise FileNotFoundError(f"File {unet_path or unet_name} does not exist")
        unet = TrTUnet(unet_path)

        # Build only the ComfyUI model wrapper; the engine replaces its UNet.
        if model_type == "sdxl_base":
            conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.SDXL(conf)
        elif model_type == "sdxl_refiner":
            conf = comfy.supported_models.SDXLRefiner(
                {"adm_in_channels": 2560})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.SDXLRefiner(conf)
        elif model_type == "sd1.x":
            conf = comfy.supported_models.SD15({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.BaseModel(conf)
        elif model_type == "sd2.x-768v":
            conf = comfy.supported_models.SD20({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.BaseModel(conf, model_type=comfy.model_base.ModelType.V_PREDICTION)
        elif model_type == "svd":
            conf = comfy.supported_models.SVD_img2vid({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
        elif model_type == "sd3":
            conf = comfy.supported_models.SD3({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
        elif model_type == "auraflow":
            conf = comfy.supported_models.AuraFlow({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
        elif model_type == "flux_dev":
            conf = comfy.supported_models.Flux({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
            unet.dtype = torch.bfloat16  # TODO: autodetect
        elif model_type == "flux_schnell":
            conf = comfy.supported_models.FluxSchnell({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
            unet.dtype = torch.bfloat16  # TODO: autodetect
        else:
            # Only reachable if _MODEL_TYPES gains an entry without a branch here.
            raise ValueError(f"TensorRTLoader: no model builder for {model_type!r}")

        model.diffusion_model = unet
        # Report zero extra memory so inputs are always passed batched up as much
        # as possible. TrTUnet submits each batch to the engine in a single call
        # and does not split it, so the engine's optimization profile must accept
        # the combined batch size.
        model.memory_required = lambda *args, **kwargs: 0
        return (comfy.model_patcher.ModelPatcher(model,
                                                 load_device=comfy.model_management.get_torch_device(),
                                                 offload_device=comfy.model_management.unet_offload_device()),)
# Node registry exported for ComfyUI: maps the node type name used in
# workflows to the class that implements it.
NODE_CLASS_MAPPINGS = dict(
    TensorRTLoader=TensorRTLoader,
)