import os

import torch

import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.supported_models
import folder_paths
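
# Register a "tensorrt" model folder with ComfyUI so that .engine files under
# models/tensorrt show up in this node's file list.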
if "tensorrt" in folder_paths.folder_names_and_paths: |
|
|
folder_paths.folder_names_and_paths["tensorrt"][0].append( |
|
|
os.path.join(folder_paths.models_dir, "tensorrt")) |
|
|
folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") |
|
|
else: |
|
|
folder_paths.folder_names_and_paths["tensorrt"] = ( |
|
|
[os.path.join(folder_paths.models_dir, "tensorrt")], {".engine"}) |
|
|
|
|
|
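
# Initialize the standard TensorRT plugin library once, then create a single
# shared Runtime used to deserialize every engine this node loads.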
import tensorrt as trt

trt.init_libnvinfer_plugins(None, "")

logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)
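
# Map a TensorRT dtype onto the matching torch dtype. getattr() guards against
# enum members that only exist in some TensorRT versions (e.g. bfloat16);
# anything unrecognized falls through to float32.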
def trt_datatype_to_torch(datatype):
    if datatype in (getattr(trt, "float16", None), getattr(trt.DataType, "HALF", None)):
        return torch.float16
    if datatype in (getattr(trt, "float32", None), getattr(trt.DataType, "FLOAT", None)):
        return torch.float32
    if hasattr(trt, "bfloat16") and datatype in (
        getattr(trt, "bfloat16", None),
        getattr(trt.DataType, "BF16", None),
    ):
        return torch.bfloat16
    if datatype in (getattr(trt, "int32", None), getattr(trt.DataType, "INT32", None)):
        return torch.int32
    return torch.float32
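
# Wraps a deserialized TensorRT engine behind the calling convention ComfyUI
# uses for diffusion models, so it can stand in for a UNet/DiT.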
class TrTUnet:
    def __init__(self, engine_path):
        with open(engine_path, "rb") as f:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()
        self.device = comfy.model_management.get_torch_device()
        # ComfyUI queries .dtype on the diffusion model (and the loader below
        # overrides it for flux engines); it doubles as the fallback dtype
        # when a TRT dtype has no torch equivalent.
        self.dtype = torch.float16

    def _trt_dtype_to_torch(self, trt_dtype):
        dt = trt_datatype_to_torch(trt_dtype)
        return dt if dt is not None else self.dtype

    def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs):
        """
        x         : [B, C, H, W]
        timesteps : [B]
        context   : [B, N, D]
        y         : [B, y_dim] (optional, SDXL etc.)

        control and transformer_options are accepted for interface
        compatibility with ComfyUI and are ignored here.
        """
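        # Gather the tensors ComfyUI always supplies, keyed by the input names
        # the exporter used when the engine was built.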
        model_inputs = {
            "x": x,
            "timesteps": timesteps,
            "context": context,
        }
        if y is not None:
            model_inputs["y"] = y
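
        # Enumerate the engine's I/O tensors via the name-based API
        # (TensorRT 8.5+).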
        tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        input_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
        output_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
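
        # Any further engine inputs (extra conditioning, etc.) are looked up
        # in **kwargs by name.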
        for name in input_names:
            if name in model_inputs:
                continue
            if name in kwargs:
                model_inputs[name] = kwargs[name]

        if len(model_inputs) != len(input_names):
            missing = [n for n in input_names if n not in model_inputs]
            raise RuntimeError(
                f"TensorRT UNet: missing required inputs for engine: {missing} "
                f"(have {list(model_inputs.keys())})"
            )
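
        # Move/cast every input onto the engine's device and expected dtype,
        # then bind its shape and buffer address to the execution context.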
        for name in input_names:
            t = model_inputs[name]
            if t.device != self.device:
                t = t.to(self.device)
            trt_dtype = self.engine.get_tensor_dtype(name)
            torch_dtype = self._trt_dtype_to_torch(trt_dtype)
            if t.dtype != torch_dtype:
                t = t.to(dtype=torch_dtype)
            # TensorRT reads raw device pointers, so the buffer must be densely
            # packed; a strided view would silently feed wrong data.
            if not t.is_contiguous():
                t = t.contiguous()
            # Keep a reference so any temporary created above stays alive until
            # execute_async_v3 has consumed it.
            model_inputs[name] = t
            self.context.set_input_shape(name, tuple(t.shape))
            self.context.set_tensor_address(name, int(t.data_ptr()))
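
        # Propagate the concrete input shapes through the network; names
        # coming back unresolved mean an input shape was never supplied.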
        unresolved = self.context.infer_shapes()
        if unresolved:
            raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {unresolved}")
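
        # With shapes resolved, allocate an output buffer per output tensor
        # and bind its address.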
        outputs = {}
        for name in output_names:
            out_dims = self.context.get_tensor_shape(name)
            out_shape = tuple(int(d) for d in out_dims)
            trt_dtype = self.engine.get_tensor_dtype(name)
            torch_dtype = self._trt_dtype_to_torch(trt_dtype)
            out_tensor = torch.empty(out_shape, device=self.device, dtype=torch_dtype)
            self.context.set_tensor_address(name, int(out_tensor.data_ptr()))
            outputs[name] = out_tensor
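
        # Enqueue inference on PyTorch's current CUDA stream; torch ops issued
        # on the same stream order after it, so no explicit synchronization is
        # needed here.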
        stream = torch.cuda.current_stream(self.device)
        self.context.execute_async_v3(stream_handle=stream.cuda_stream)
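
        # Most engines have a single output (typically the model prediction);
        # return it directly in that case.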
        out_list = [outputs[name] for name in output_names]
        return out_list[0] if len(out_list) == 1 else tuple(out_list)
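
    # ComfyUI expects these on a diffusion model; all weights live inside the
    # serialized engine, so both are intentional no-ops.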
    def load_state_dict(self, sd, strict=False):
        pass

    def state_dict(self):
        return {}


class TensorRTLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "unet_name": (folder_paths.get_filename_list("tensorrt"),),
            "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v",
                            "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"],),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "TensorRT"

    def load_unet(self, unet_name, model_type):
        unet_path = folder_paths.get_full_path("tensorrt", unet_name)
        # get_full_path returns None when the file is missing, so guard both.
        if unet_path is None or not os.path.isfile(unet_path):
            raise FileNotFoundError(f"TensorRT engine {unet_name} does not exist")
        unet = TrTUnet(unet_path)
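
        # Build the ComfyUI model config matching the selected architecture.
        # disable_unet_model_creation skips instantiating the PyTorch UNet,
        # since the TensorRT engine replaces it entirely.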
        if model_type == "sdxl_base":
            conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.SDXL(conf)
        elif model_type == "sdxl_refiner":
            conf = comfy.supported_models.SDXLRefiner({"adm_in_channels": 2560})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.SDXLRefiner(conf)
        elif model_type == "sd1.x":
            conf = comfy.supported_models.SD15({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.BaseModel(conf)
        elif model_type == "sd2.x-768v":
            conf = comfy.supported_models.SD20({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = comfy.model_base.BaseModel(conf, model_type=comfy.model_base.ModelType.V_PREDICTION)
        elif model_type == "svd":
            conf = comfy.supported_models.SVD_img2vid({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
        elif model_type == "sd3":
            conf = comfy.supported_models.SD3({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
        elif model_type == "auraflow":
            conf = comfy.supported_models.AuraFlow({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
        elif model_type == "flux_dev":
            conf = comfy.supported_models.Flux({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
            unet.dtype = torch.bfloat16
        elif model_type == "flux_schnell":
            conf = comfy.supported_models.FluxSchnell({})
            conf.unet_config["disable_unet_model_creation"] = True
            model = conf.get_model({})
            unet.dtype = torch.bfloat16
        else:
            # Defensive: the combo box above should make this unreachable.
            raise ValueError(f"Unknown model_type: {model_type}")
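
        # Swap the TRT wrapper in as the diffusion model, and report zero
        # memory so ComfyUI's VRAM planner does not try to manage the engine's
        # device memory itself.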
        model.diffusion_model = unet
        model.memory_required = lambda *args, **kwargs: 0

        return (comfy.model_patcher.ModelPatcher(
            model,
            load_device=comfy.model_management.get_torch_device(),
            offload_device=comfy.model_management.unet_offload_device(),
        ),)


NODE_CLASS_MAPPINGS = {
    "TensorRTLoader": TensorRTLoader,
}