| | {% if add_header %} |
| | |
| |
|
| | import ctypes |
| | from collections import OrderedDict |
| | from pathlib import Path |
| | from typing import List |
| |
|
| | import numpy as np |
| | import tensorrt as trt |
| |
|
| | from tensorrt_llm._common import default_trtnet |
| | from tensorrt_llm._utils import str_dtype_to_trt |
| | from tensorrt_llm.functional import Tensor, _create_tensor |
| | from tensorrt_llm.module import Module |
| |
|
# Plugin namespace passed to the shared library's initLibNvInferPlugins() at
# load time; the generated wrappers use the same namespace when looking up
# their plugin creators in the TensorRT plugin registry.
TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm'


def _load_triton_plugin_lib():
    """Load the generated plugin shared library and register its plugins.

    The library at the rendered ``plugin_lib_path`` is loaded with
    ``RTLD_GLOBAL`` and its ``initLibNvInferPlugins`` entry point is called
    with ``TRT_LLM_PLUGIN_NAMESPACE`` (UTF-8 encoded).

    Raises:
        ImportError: if the library cannot be loaded, or if
            ``initLibNvInferPlugins`` returns false.
    """
    plugin_lib = "[[ plugin_lib_path ]]"
    try:
        handle = ctypes.CDLL(plugin_lib, mode=ctypes.RTLD_GLOBAL)
    except OSError as err:
        # ctypes.CDLL reports a missing or unloadable library by raising
        # OSError (it never returns None), so surface it as the intended
        # ImportError while keeping the original cause attached.
        raise ImportError('TensorRT-LLM Triton Plugin is unavailable') from err

    init_plugins = handle.initLibNvInferPlugins
    init_plugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    init_plugins.restype = ctypes.c_bool
    # The call must not live inside an ``assert``: under ``python -O`` the
    # whole assert statement (and therefore the registration) is skipped.
    if not init_plugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8')):
        raise ImportError(
            'Failed to initialize the TensorRT-LLM Triton Plugin library')


# Register the plugins as a side effect of importing this module.
_load_triton_plugin_lib()
| |
|
| | |
| | {% endif %} |
| |
|
# Generated functional wrapper. It looks up this kernel's plugin creator in the
# TensorRT plugin registry under TRT_LLM_PLUGIN_NAMESPACE, builds the plugin
# with one PluginField per parameter, adds it to the network returned by
# default_trtnet(), and returns a list with one tensor per plugin output.
def [[ kernel_name ]]([[ arg_list ]]):
    '''
    Inputs:
    {% for arg in metadata.get_params() -%}
    - [[arg.name]]: [[arg.dtype.dtype.to('np')]]
    {% endfor %}
    {% for arg in metadata.get_inputs() -%}
    - [[arg.name]]: {% if arg.is_tensor %}tensor<[[arg.dtype.dtype.to('np')]]>{% else %}[[arg.dtype.dtype.to('np')]]{% endif %}
    {% endfor %}
    Outputs:
    {% for arg in metadata.get_outputs() -%}
    - [[arg.name]]: {% if arg.is_tensor %}tensor<[[arg.dtype.dtype.to('np')]]>{% else %}[[arg.dtype.dtype.to('np')]]{% endif %}
    {% endfor -%}
    '''
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        '[[ plugin_name ]]', '[[ kernel_version ]]', TRT_LLM_PLUGIN_NAMESPACE)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``
    # (otherwise a missing creator only surfaces later as an AttributeError).
    if plg_creator is None:
        raise RuntimeError(
            'TensorRT plugin creator "[[ plugin_name ]]" (version '
            '[[ kernel_version ]]) is not registered in namespace ' +
            TRT_LLM_PLUGIN_NAMESPACE)

    # One single-element PluginField per plugin parameter.
    pfc = trt.PluginFieldCollection([
        {% for arg in params -%}
        trt.PluginField("[[arg.name]]", np.array([ [[ arg.name ]] ], np.[[ arg.dtype.dtype.to('np') ]]),
                        trt.PluginFieldType.[[ arg.dtype.dtype.to('trt_plugin_py') ]]),
        {% endfor %}
    ])

    plugin = plg_creator.create_plugin("[[ plugin_name ]]", pfc)
    if plugin is None:
        raise RuntimeError(
            'Failed to create TensorRT plugin "[[ plugin_name ]]"')

    plug_inputs = [ [[ input_list ]] ]
    layer = default_trtnet().add_plugin_v2(plug_inputs, plugin)

    return [
        {% for idx in range(num_outputs) %}
        _create_tensor(layer.get_output([[ idx ]]), layer),
        {% endfor %}
    ]
| |
|