| | from dataclasses import dataclass |
| | from typing import Dict, List, Optional |
| |
|
| | import numpy as np |
| | import tensorrt as trt |
| | import torch |
| |
|
| | from tensorrt_llm._utils import trt_dtype_to_str, trt_dtype_to_torch |
| | from tensorrt_llm.logger import logger |
| | from tensorrt_llm.network import Network, get_plugin_info, set_plugin_info |
| | from tensorrt_llm.plugin.plugin import PluginConfig |
| | from tensorrt_llm.runtime.session import Session |
| |
|
| | from .utils import (current_flags, get_builder_flags, get_sorted_layer_ids, |
| | get_strongly_typed, get_trt_network, set_trt_network, |
| | to_base_class_layer, to_subclass_layer) |
| |
|
| |
|
class Tensor:
    """Graph-side wrapper around a TensorRT tensor.

    Besides the wrapped TensorRT object, the wrapper records the inferred
    shape / max shape / value, the producing layer, the consuming layers
    and whether the tensor is a graph input and/or output.
    """

    def __init__(self, graph: "PipelineGraph"):
        self._graph = graph
        self._trt = None
        self._shape = None
        self._max_shape = None
        self._value = None
        # Layer whose output this tensor is, and the index of that output.
        self.producer: Layer = None
        self.output_index = None
        # (layer, input_index) pairs, appended by Layer.from_trt.
        self.consumers = []
        # -1 means "not a graph input" / "not a graph output".
        self.graph_input_index = -1
        self.graph_output_index = -1
        # Free-form annotations attached by users of the graph.
        self.attrs = {}

    @staticmethod
    def from_trt(graph: "PipelineGraph", trt_tensor: "trt.ITensor"):
        """Wrap ``trt_tensor`` in a new ``Tensor`` owned by ``graph``."""
        tensor = Tensor(graph)
        tensor._trt = trt_tensor
        return tensor

    def as_trt(self) -> "trt.ITensor":
        """Return the wrapped TensorRT tensor."""
        return self._trt

    def copy(self) -> "Tensor":
        """Return a shallow copy of this wrapper.

        ``consumers`` and ``attrs`` are duplicated one level deep so that
        mutating the copy's containers does not affect the original; every
        other field (including the wrapped TensorRT tensor) is shared.
        """
        tensor = Tensor(self._graph)
        tensor._trt = self._trt
        tensor._shape = self._shape
        tensor._max_shape = self._max_shape
        tensor._value = self._value
        tensor.producer = self.producer
        tensor.output_index = self.output_index
        tensor.consumers = [*self.consumers]
        tensor.graph_input_index = self.graph_input_index
        tensor.graph_output_index = self.graph_output_index
        tensor.attrs = self.attrs.copy()
        return tensor

    @property
    def graph(self) -> "PipelineGraph":
        """The graph this tensor belongs to."""
        return self._graph

    @property
    def name(self) -> str:
        """Name of the wrapped TensorRT tensor."""
        return self._trt.name

    @name.setter
    def name(self, name: str):
        """Rename the tensor and re-key it in the owning graph's registries.

        The input and output registries are handled independently: a tensor
        may be both a graph input and a graph output, and in that case both
        registries must drop the old key.
        """
        old_name = self._trt.name
        if name == old_name:
            return
        self._trt.name = name
        graph = self.graph
        graph._tensors[name] = self
        del graph._tensors[old_name]
        if self.is_graph_input:
            graph._inputs[name] = self
            del graph._inputs[old_name]
        if self.is_graph_output:
            graph._outputs[name] = self
            del graph._outputs[old_name]

    @property
    def shape(self):
        """Shape recorded on the wrapper (``None`` until assigned)."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def max_shape(self):
        """Max shape recorded on the wrapper (``None`` until assigned)."""
        return self._max_shape

    @max_shape.setter
    def max_shape(self, max_shape):
        self._max_shape = max_shape

    @property
    def raw_shape(self):
        """Shape as reported by the wrapped TensorRT tensor itself."""
        assert isinstance(self._trt, trt.ITensor)
        return self._trt.shape

    @raw_shape.setter
    def raw_shape(self, raw_shape):
        assert isinstance(self._trt, trt.ITensor)
        self._trt.shape = raw_shape

    @property
    def value(self):
        """Value recorded on the wrapper (``None`` when unknown)."""
        return self._value

    @value.setter
    def value(self, value):
        self._value = value

    @property
    def dtype(self):
        """Data type of the wrapped TensorRT tensor."""
        return self._trt.dtype

    @property
    def broadcast_across_batch(self):
        """``broadcast_across_batch`` flag of the wrapped TensorRT tensor."""
        return self._trt.broadcast_across_batch

    @property
    def dtype_size(self):
        """``itemsize`` of the tensor's data type."""
        return self.dtype.itemsize

    @property
    def dtype_str(self):
        """Data type converted with ``trt_dtype_to_str``."""
        return trt_dtype_to_str(self.dtype)

    @property
    def dtype_str_size(self):
        """Two-element list ``[dtype_str, dtype_size]``."""
        return [trt_dtype_to_str(self.dtype), self.dtype.itemsize]

    @property
    def is_graph_input(self) -> bool:
        return self.graph_input_index != -1

    @property
    def is_graph_output(self) -> bool:
        return self.graph_output_index != -1

    @property
    def is_graph_io(self) -> bool:
        return self.is_graph_input or self.is_graph_output
|
| |
|
class Layer:
    """Graph-side wrapper around a TensorRT layer.

    Holds the wrapped TensorRT layer, its position in the network and the
    wrapped input / output tensors, plus a free-form ``attrs`` dict.
    """

    def __init__(self, graph):
        self.attrs = {}
        self._is_shape_io = False
        self._outputs = []
        self._inputs = []
        self._index = None
        self._trt = None
        self._graph = graph

    @staticmethod
    def from_trt(graph, trt_layer, index):
        """Wrap ``trt_layer`` and connect it to tensors registered in ``graph``.

        Each present input gets ``(layer, slot)`` appended to its
        ``consumers``; missing inputs are kept as ``None`` placeholders so
        slot numbers stay aligned. Each output gets its ``producer`` and
        ``output_index`` set.
        """
        wrapper = Layer(graph)
        wrapper._trt = trt_layer
        wrapper._index = index

        for slot in range(trt_layer.num_inputs):
            trt_in = trt_layer.get_input(slot)
            if trt_in is None:
                wrapper._inputs.append(None)
                continue
            in_tensor = graph.get_tensor(trt_in.name)
            wrapper._inputs.append(in_tensor)
            in_tensor.consumers.append((wrapper, slot))

        for slot in range(trt_layer.num_outputs):
            trt_out = trt_layer.get_output(slot)
            out_tensor = graph.get_tensor(trt_out.name)
            wrapper._outputs.append(out_tensor)
            out_tensor.producer = wrapper
            out_tensor.output_index = slot

        set_trt_network(trt_layer, graph.as_trt())
        return wrapper

    def as_trt(self) -> trt.ILayer:
        """Return the wrapped TensorRT layer."""
        return self._trt

    @property
    def graph(self) -> "PipelineGraph":
        """The graph this layer belongs to."""
        return self._graph

    @property
    def name(self) -> str:
        """Name of the wrapped TensorRT layer."""
        return self._trt.name

    @name.setter
    def name(self, name: str):
        """Rename the layer and re-key it in the owning graph."""
        previous = self._trt.name
        if name == previous:
            return
        self._trt.name = name
        registry = self.graph._layers
        registry[name] = self
        del registry[previous]

    @property
    def type(self) -> trt.LayerType:
        """``type`` attribute of the wrapped TensorRT layer."""
        return self._trt.type

    @property
    def index(self) -> int:
        """Index given to ``from_trt`` when the wrapper was created."""
        return self._index

    @property
    def inputs(self) -> List[Tensor]:
        """Wrapped input tensors (``None`` for missing inputs)."""
        return self._inputs

    @property
    def outputs(self) -> List[Tensor]:
        """Wrapped output tensors."""
        return self._outputs

    def get_input(self, index: int) -> Tensor:
        return self._inputs[index]

    def get_output(self, index: int) -> Tensor:
        return self._outputs[index]

    @property
    def num_inputs(self) -> int:
        """Input count as reported by the wrapped TensorRT layer."""
        return self._trt.num_inputs

    @property
    def num_outputs(self) -> int:
        """Output count as reported by the wrapped TensorRT layer."""
        return self._trt.num_outputs

    @property
    def is_shape_io(self) -> bool:
        """Flag set by the graph's shape assignment; ``False`` by default."""
        return self._is_shape_io

    def to_subclass(self):
        """Apply ``to_subclass_layer`` to the wrapped TensorRT layer."""
        to_subclass_layer(self._trt)

    def to_base_class(self):
        """Apply ``to_base_class_layer`` to the wrapped TensorRT layer."""
        to_base_class_layer(self._trt)

    def assign_shapes(self, shapes, values):
        """Copy shapes (required) and values (optional) onto the outputs.

        ``shapes`` must contain every output name; ``values`` may omit
        names, in which case the output's value becomes ``None``.
        """
        for out_tensor in self.outputs:
            out_tensor.shape = shapes[out_tensor.name]
            out_tensor.value = values.get(out_tensor.name)
|
| |
|
@dataclass
class GraphRunner:
    """Bundle of a session, its I/O buffers and the stream to launch on."""

    session: "Session"
    inputs: Dict[str, torch.Tensor]
    outputs: Dict[str, torch.Tensor]
    stream: torch.Stream

    def run(self):
        """Run the session once and return the output buffers.

        The session is launched on ``stream`` and the stream is
        synchronized before returning.

        Raises:
            AssertionError: if ``session.run`` reports failure.
        """
        cuda_stream = self.stream.cuda_stream
        # Keep the launch outside of an ``assert`` statement: asserts are
        # stripped under ``python -O``, which would skip the run entirely.
        succeeded = self.session.run(self.inputs, self.outputs, cuda_stream)
        if not succeeded:
            raise AssertionError("Session.run reported failure")
        self.stream.synchronize()
        return self.outputs
| |
|
| |
|
class PipelineGraph:
    """Wrapper around a TensorRT network with name-indexed tensors and layers.

    The graph keeps dictionaries of wrapped inputs, outputs, layers and
    tensors keyed by name. It can be built by wrapping an existing network
    (``from_trt`` / ``from_network``) or by cloning layers one at a time
    into a fresh network (``create_graph`` + ``add_input`` / ``add_layer``
    / ``add_output``).
    """

    def __init__(self):
        self._trt = None
        self._inputs: Dict[str, Tensor] = {}
        self._outputs: Dict[str, Tensor] = {}
        self._layers: Dict[str, Layer] = {}
        self._tensors: Dict[str, Tensor] = {}
        # Output name -> input name; get_runner reuses the input's buffer
        # for the mapped output.
        self._io_buffer_mapping = {}
        # Layer name -> (weights, values), see _register_unfilled_weights.
        self._unfilled_weights = {}
        self._auto_parallel_config = None
        self._plugin_config: PluginConfig = None

    @staticmethod
    def create_graph():
        """Create a graph backed by a new, empty TensorRT network.

        The network is created with the EXPLICIT_BATCH flag when the
        installed TensorRT still defines it, and with STRONGLY_TYPED when
        ``get_strongly_typed()`` is true.
        """
        graph = PipelineGraph()
        trt_builder = trt.Builder(logger.trt_logger)
        explicit_batch_flag = 0
        creation_flags = trt.NetworkDefinitionCreationFlag
        if "EXPLICIT_BATCH" in creation_flags.__members__.keys():
            explicit_batch_flag = 1 << int(creation_flags.EXPLICIT_BATCH)
        if get_strongly_typed():
            network = trt_builder.create_network(
                explicit_batch_flag
                | (1 << int(creation_flags.STRONGLY_TYPED)))
        else:
            network = trt_builder.create_network(explicit_batch_flag)
        graph._trt = network
        return graph

    def _register_unfilled_weights(self, layer_name, weights, values):
        """Record ``(weights, values)`` for ``layer_name``."""
        self._unfilled_weights[layer_name] = (weights, values)

    def _add_tensor(self, tensor, old_tensor, prefix):
        """Copy metadata from ``old_tensor`` onto ``tensor`` and register it.

        The name (optionally prefixed), location and dynamic range are
        copied. For network inputs the shape and dimension names are
        copied as well.
        """
        if prefix is not None:
            tensor.name = prefix + old_tensor.name
        else:
            tensor.name = old_tensor.name
        tensor.location = old_tensor.location
        if old_tensor.dynamic_range is not None:
            tensor.dynamic_range = old_tensor.dynamic_range
        if tensor.is_network_input:
            tensor.shape = old_tensor.shape
            for dim in range(len(old_tensor.shape)):
                dim_name = old_tensor.get_dimension_name(dim)
                if dim_name is not None:
                    tensor.set_dimension_name(dim, dim_name)
        return self._register_tensor(tensor)

    def _register_tensor(self, tensor):
        """Wrap a TensorRT tensor and index it by name (must be unused)."""
        wrapped_tensor = Tensor.from_trt(self, tensor)
        assert tensor.name not in self._tensors
        self._tensors[tensor.name] = wrapped_tensor
        return wrapped_tensor

    def add_input(self, tensor, prefix=None):
        """Add a network input mirroring ``tensor`` and return its wrapper."""
        tensor_name = tensor.name
        if prefix is not None:
            tensor_name = prefix + tensor_name
        trt_input = self._trt.add_input(tensor_name, tensor.dtype,
                                        tensor.shape)
        new_tensor = self._add_tensor(trt_input, tensor, prefix)
        new_tensor.graph_input_index = len(self._inputs)
        self._inputs[tensor_name] = new_tensor
        return new_tensor

    def register_input(self, tensor, index=None):
        """Register an input that already exists in the wrapped network.

        ``index`` defaults to the last network input; the network input at
        that index must have the same name as ``tensor``.
        """
        if index is None:
            index = self.num_inputs - 1
        assert self._trt.get_input(index).name == tensor.name
        wrapped_input = self._register_tensor(tensor)
        wrapped_input.graph_input_index = index
        self._inputs[tensor.name] = wrapped_input
        return wrapped_input

    def add_output(self, tensor, prefix=None):
        """Mark the already-registered tensor named like ``tensor`` as output.

        The output's dtype is set to ``tensor.dtype``.
        """
        tensor_name = tensor.name
        if prefix is not None:
            tensor_name = prefix + tensor_name
        output = self.get_tensor(tensor_name)
        output.graph_output_index = len(self._outputs)
        trt_output = output.as_trt()
        self._trt.mark_output(trt_output)
        trt_output.dtype = tensor.dtype
        self._outputs[tensor_name] = output
        return output

    def add_output_shape(self, tensor, prefix=None):
        """Mark the already-registered tensor as a shape output.

        Unlike ``add_output`` this does not assign ``graph_output_index``.
        """
        tensor_name = tensor.name
        if prefix is not None:
            tensor_name = prefix + tensor_name
        output = self.get_tensor(tensor_name)
        trt_output = output.as_trt()
        self._trt.mark_output_for_shapes(trt_output)
        trt_output.dtype = tensor.dtype
        self._outputs[tensor_name] = output
        return output

    def add_layer(
        self,
        layer,
        input_mapping=None,
        prefix=None,
        updated_attrs=None,
    ) -> Layer:
        """Clone a TensorRT layer into this graph's network.

        Args:
            layer: source TensorRT layer (from another network).
            input_mapping: optional dict redirecting (prefixed) input
                tensor names to other tensor names in this graph.
            prefix: optional string prepended to layer and tensor names.
            updated_attrs: optional dict of attributes to set on the new
                layer. For plugin layers a non-None ``"plugin"`` entry
                replaces the plugin and the other entries are ignored.

        Returns:
            The wrapper of the newly added layer.

        Raises:
            NotImplementedError: if the layer type is not supported.
        """

        def get_input(i):
            # Resolve the i-th source input to the matching tensor of this
            # graph's network.
            name = layer.get_input(i).name
            if prefix is not None:
                name = prefix + name
            if input_mapping is not None and name in input_mapping:
                name = input_mapping[name]
            return self.get_tensor(name).as_trt()

        network = self._trt
        # Read the generic layer type before casting: after the cast
        # ``layer.type`` is looked up on the subclass instead.
        layer_type = layer.type
        updated = False
        to_subclass_layer(layer)
        try:
            if layer_type == trt.LayerType.ACTIVATION:
                trt_input = get_input(0)
                new_layer = network.add_activation(trt_input, layer.type)
                new_layer.alpha = layer.alpha
                new_layer.beta = layer.beta
            elif layer_type == trt.LayerType.CONCATENATION:
                trt_inputs = [get_input(i) for i in range(layer.num_inputs)]
                new_layer = network.add_concatenation(trt_inputs)
                new_layer.axis = layer.axis
            elif layer_type == trt.LayerType.CONSTANT:
                new_layer = network.add_constant(layer.shape, layer.weights)
            elif layer_type == trt.LayerType.ELEMENTWISE:
                new_layer = network.add_elementwise(get_input(0),
                                                    get_input(1), layer.op)
            elif layer_type == trt.LayerType.FILL:
                if layer.num_inputs >= 1 and layer.get_input(0) is not None:
                    shape_input = get_input(0)
                    # Placeholder shape; the shape input set below is
                    # what the source layer uses.
                    shape = [1]
                else:
                    shape_input = None
                    shape = layer.shape
                new_layer = network.add_fill(shape, layer.operation,
                                             layer.to_type)
                if shape_input is not None:
                    new_layer.set_input(0, shape_input)
                if layer.num_inputs >= 2 and layer.get_input(1) is not None:
                    new_layer.set_input(1, get_input(1))
                else:
                    new_layer.alpha = layer.alpha
                if layer.num_inputs >= 3 and layer.get_input(2) is not None:
                    new_layer.set_input(2, get_input(2))
                else:
                    new_layer.beta = layer.beta
            elif layer_type == trt.LayerType.GATHER:
                trt_input = get_input(0)
                trt_indices = get_input(1)
                new_layer = network.add_gather_v2(trt_input, trt_indices,
                                                  layer.mode)
                new_layer.axis = layer.axis
                new_layer.num_elementwise_dims = layer.num_elementwise_dims
                new_layer.mode = layer.mode
            elif layer_type == trt.LayerType.MATRIX_MULTIPLY:
                new_layer = network.add_matrix_multiply(
                    get_input(0), layer.op0, get_input(1), layer.op1)
            elif layer_type == trt.LayerType.REDUCE:
                new_layer = network.add_reduce(get_input(0), layer.op,
                                               layer.axes, layer.keep_dims)
            elif layer_type == trt.LayerType.SELECT:
                trt_condition = get_input(0)
                trt_then = get_input(1)
                trt_else = get_input(2)
                new_layer = network.add_select(trt_condition, trt_then,
                                               trt_else)
            elif layer_type == trt.LayerType.SHUFFLE:
                new_layer = network.add_shuffle(get_input(0))
                new_layer.first_transpose = layer.first_transpose
                new_layer.second_transpose = layer.second_transpose
                new_layer.zero_is_placeholder = layer.zero_is_placeholder
                if layer.num_inputs >= 2:
                    trt_reshape_dims_tensor = get_input(1)
                    new_layer.set_input(1, trt_reshape_dims_tensor)
                else:
                    new_layer.reshape_dims = layer.reshape_dims
            elif layer_type == trt.LayerType.SLICE:
                # start / shape / stride each come either from an input
                # tensor or from the layer's static attribute.
                if layer.num_inputs >= 2 and layer.get_input(1) is not None:
                    trt_start = get_input(1)
                    start = []
                else:
                    trt_start = None
                    start = layer.start
                if layer.num_inputs >= 3 and layer.get_input(2) is not None:
                    trt_shape = get_input(2)
                    shape = []
                else:
                    trt_shape = None
                    shape = layer.shape
                if layer.num_inputs >= 4 and layer.get_input(3) is not None:
                    trt_stride = get_input(3)
                    stride = []
                else:
                    trt_stride = None
                    stride = layer.stride
                new_layer = network.add_slice(get_input(0), start, shape,
                                              stride)
                new_layer.mode = layer.mode
                if trt_start is not None:
                    new_layer.set_input(1, trt_start)
                if trt_shape is not None:
                    new_layer.set_input(2, trt_shape)
                if trt_stride is not None:
                    new_layer.set_input(3, trt_stride)
            elif layer_type == trt.LayerType.SOFTMAX:
                new_layer = network.add_softmax(get_input(0))
                new_layer.axes = layer.axes
            elif layer_type == trt.LayerType.UNARY:
                new_layer = network.add_unary(get_input(0), layer.op)
            elif layer_type == trt.LayerType.SHAPE:
                new_layer = network.add_shape(get_input(0))
            elif layer_type == trt.LayerType.ASSERTION:
                new_layer = network.add_assertion(get_input(0),
                                                  layer.message)
            elif layer_type == trt.LayerType.CAST:
                new_layer = network.add_cast(get_input(0), layer.to_type)
            elif layer_type == trt.LayerType.NORMALIZATION:
                trt_input = get_input(0)
                trt_scale = get_input(1)
                trt_bias = get_input(2)
                new_layer = network.add_normalization(
                    trt_input, trt_scale, trt_bias, layer.axes)
                new_layer.epsilon = layer.epsilon
                new_layer.num_groups = layer.num_groups
                new_layer.compute_precision = layer.compute_precision
            elif layer_type == trt.LayerType.IDENTITY:
                new_layer = network.add_identity(get_input(0))
            elif layer_type == trt.LayerType.PLUGIN_V2:
                plugin = layer.plugin
                if (updated_attrs is not None
                        and updated_attrs.get("plugin") is not None):
                    plugin = updated_attrs["plugin"]
                    updated = True
                    # The replacement plugin is the whole update; skip
                    # the generic setattr pass below.
                    updated_attrs = None
                new_layer = network.add_plugin_v2(
                    [get_input(i) for i in range(layer.num_inputs)],
                    plugin,
                )
            else:
                raise NotImplementedError(
                    "Unsupported layer type: {}".format(layer_type))

            if updated_attrs is not None:
                for attr_name, attr_value in updated_attrs.items():
                    setattr(new_layer, attr_name, attr_value)
        finally:
            # Always undo the cast of the source layer, also when cloning
            # failed, so the caller's layer is left as it was.
            to_base_class_layer(layer)

        to_base_class_layer(new_layer)
        layer_index = network.num_layers - 1
        layer_name = layer.name
        if prefix is not None:
            layer_name = prefix + layer_name
        new_layer.name = layer_name
        new_layer.metadata = new_layer.name
        if layer.precision_is_set:
            new_layer.precision = layer.precision
        for i in range(layer.num_outputs):
            if layer.output_type_is_set(i):
                new_layer.set_output_type(i, layer.get_output_type(i))
            output = new_layer.get_output(i)
            self._add_tensor(output, layer.get_output(i), prefix)
        wrapped_layer = Layer.from_trt(self, new_layer, layer_index)
        assert layer_name not in self._layers
        self._layers[layer_name] = wrapped_layer
        if layer_type == trt.LayerType.PLUGIN_V2:
            if not updated:
                # Plugin unchanged: carry over the plugin info recorded
                # for the source layer in its own network.
                plugin_info = get_plugin_info(get_trt_network(layer),
                                              layer.name)
                set_plugin_info(self.as_trt(), new_layer.name, plugin_info)
        return wrapped_layer

    def register_layer(self, layer, index=None):
        """Register a layer that already exists in the wrapped network.

        ``index`` defaults to the last network layer; the network layer at
        that index must have the same name as ``layer``. The layer's
        outputs are registered as tensors first. The layer is cast to its
        base class while being wrapped and back to its subclass afterwards.
        """
        if index is None:
            index = self.num_layers - 1
        assert self._trt.get_layer(index).name == layer.name
        to_base_class_layer(layer)
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            self._register_tensor(output)
        wrapped_layer = Layer.from_trt(self, layer, index)
        assert layer.name not in self._layers
        self._layers[layer.name] = wrapped_layer
        to_subclass_layer(layer)
        return wrapped_layer

    def get_runner(
        self,
        shapes=None,
        values=None,
        profile=None,
        timing_cache=None,
        opt_level=None,
    ) -> GraphRunner:
        """Build an engine for this graph and prepare buffers to run it.

        Args:
            shapes: optional dict of tensor name -> shape, overriding the
                shapes recorded on the tensors.
            values: optional dict of tensor name -> value, overriding the
                values recorded on the tensors.
            profile: optimization profile; a fresh one is created if falsy.
            timing_cache: optional timing cache passed to the builder
                config.
            opt_level: optional builder optimization level.

        Returns:
            A ``GraphRunner`` holding the session, buffers and stream.

        Raises:
            RuntimeError: if engine building fails.
        """
        shapes = shapes or {}
        values = values or {}
        inputs = {}
        outputs = {}
        for graph_input in self.inputs:
            if graph_input is None:
                continue
            input_name = graph_input.name
            value = values.get(input_name)
            if value is None:
                value = graph_input.value
            if value is not None:
                if not isinstance(value, torch.Tensor):
                    value = torch.tensor(
                        value,
                        dtype=trt_dtype_to_torch(graph_input.dtype),
                        device='cpu',
                    )
                inputs[input_name] = value
            else:
                shape = shapes.get(input_name)
                if shape is None:
                    shape = graph_input.shape
                assert shape is not None
                inputs[input_name] = torch.empty(
                    tuple(shape),
                    dtype=trt_dtype_to_torch(graph_input.dtype),
                    device=torch.cuda.current_device(),
                )
                # Fill floating-point placeholders with random data.
                if torch.is_floating_point(inputs[input_name]):
                    inputs[input_name].normal_()

        for output in self.outputs:
            if output.as_trt().is_shape_tensor:
                continue
            if output.name in self._io_buffer_mapping:
                # Reuse the mapped input's buffer when that input exists.
                input_name = self._io_buffer_mapping[output.name]
                if input_name in inputs:
                    outputs[output.name] = inputs[input_name]
                    continue
            value = values.get(output.name)
            if value is not None and isinstance(value, torch.Tensor):
                outputs[output.name] = value
            else:
                shape = shapes.get(output.name)
                if shape is None:
                    shape = output.shape
                assert shape is not None
                outputs[output.name] = torch.empty(
                    tuple(shape),
                    dtype=trt_dtype_to_torch(output.dtype),
                    device=torch.cuda.current_device(),
                )

        network = self.as_trt()
        config = network.builder.create_builder_config()
        if opt_level is not None:
            config.builder_optimization_level = opt_level
        config.flags = get_builder_flags()
        profile = profile or network.builder.create_optimization_profile()
        profile_index = config.add_optimization_profile(profile)
        if timing_cache is not None:
            config.set_timing_cache(timing_cache, ignore_mismatch=False)
        plan = network.builder.build_serialized_network(network, config)
        if plan is None:
            message = 'Engine building failed, please check the error log.'
            logger.error(message)
            # Stop here instead of handing ``None`` to the session loader.
            raise RuntimeError(message)
        session = Session.from_serialized_engine(plan)
        stream = torch.cuda.current_stream()
        cuda_stream = stream.cuda_stream
        context = session.context
        context.set_optimization_profile_async(profile_index, cuda_stream)
        runner = GraphRunner(session, inputs, outputs, stream)
        return runner

    def run(
        self,
        shapes=None,
        values=None,
        profile=None,
        timing_cache=None,
        opt_level=None,
    ):
        """Build a runner with ``get_runner`` and execute it once."""
        return self.get_runner(
            shapes,
            values,
            profile,
            timing_cache,
            opt_level,
        ).run()

    def duplicate_graph(self):
        """Clone the wrapped network into a brand-new graph.

        Inputs are copied first, then layers in the order given by
        ``get_sorted_layer_ids``, then outputs (shape outputs via
        ``add_output_shape``).
        """
        graph = PipelineGraph.create_graph()
        network = self.as_trt()
        for i in range(network.num_inputs):
            graph.add_input(network.get_input(i))
        for layer_id in get_sorted_layer_ids(network):
            graph.add_layer(network.get_layer(layer_id))
        for i in range(network.num_outputs):
            output = network.get_output(i)
            if output.is_shape_tensor:
                graph.add_output_shape(output)
            else:
                graph.add_output(output)
        return graph

    @staticmethod
    def from_trt(trt_network):
        """Wrap an existing TensorRT network without modifying it."""
        graph = PipelineGraph()
        graph._trt = trt_network

        # Pass 1: wrap every tensor (network inputs and layer outputs) so
        # that layers can look their tensors up by name in pass 2.
        for input_index in range(trt_network.num_inputs):
            trt_input = trt_network.get_input(input_index)
            tensor = Tensor.from_trt(graph, trt_input)
            tensor.graph_input_index = input_index
            graph._tensors[tensor.name] = tensor
            graph._inputs[tensor.name] = tensor
        for layer_index in range(trt_network.num_layers):
            trt_layer = trt_network.get_layer(layer_index)
            for output_index in range(trt_layer.num_outputs):
                trt_output = trt_layer.get_output(output_index)
                tensor = Tensor.from_trt(graph, trt_output)
                graph._tensors[tensor.name] = tensor

        # Pass 2: wrap the layers and mark the graph outputs.
        for layer_index in range(trt_network.num_layers):
            layer = Layer.from_trt(graph, trt_network.get_layer(layer_index),
                                   layer_index)
            graph._layers[layer.name] = layer
        for output_index in range(trt_network.num_outputs):
            tensor_name = trt_network.get_output(output_index).name
            output_tensor = graph._tensors[tensor_name]
            output_tensor.graph_output_index = output_index
            graph._outputs[tensor_name] = output_tensor

        return graph

    @staticmethod
    def from_network(network: Network, builder_config):
        """Wrap a ``Network`` and infer shapes with its last profile.

        Both steps run under ``current_flags`` with the builder flags of
        ``builder_config`` and the network's ``strongly_typed`` setting.
        """
        builder_flags = builder_config.trt_builder_config.flags
        with current_flags(builder_flags, network.strongly_typed):
            graph = PipelineGraph.from_trt(network.trt_network)
            graph.infer_shapes(network._generate_optimization_profiles()[-1])
        return graph

    def assign_shapes(self, shape_info=None, is_partial=False):
        """Record shapes, max shapes and values on the wrapped tensors.

        Args:
            shape_info: object exposing ``shapes``, ``max_shapes``,
                ``values`` and ``shape_layers``. When ``None``, every
                tensor's shape is taken from the TensorRT tensor itself.
            is_partial: when true, tensors missing from ``shape_info`` are
                skipped instead of raising.

        Raises:
            ValueError: if a tensor is missing from ``shape_info`` and
                ``is_partial`` is false.
        """
        if shape_info is None:
            for tensor in self.tensors:
                tensor.shape = tensor.raw_shape
            return
        for tensor in self.tensors:
            if tensor.name in shape_info.shapes:
                tensor.shape = shape_info.shapes[tensor.name]
            elif not is_partial:
                raise ValueError(
                    f"Cannot find shape for tensor: {tensor.name}")
            if shape_info.max_shapes is not None:
                if tensor.name in shape_info.max_shapes:
                    tensor.max_shape = shape_info.max_shapes[tensor.name]
                elif not is_partial:
                    raise ValueError(
                        f"Cannot find max shape for tensor: {tensor.name}")
            if tensor.name in shape_info.values:
                tensor.value = shape_info.values[tensor.name]
        for layer in self.layers:
            if layer.name in shape_info.shape_layers:
                layer._is_shape_io = True

    def infer_shapes(self, profile=None):
        """Compute shape info for the network and assign it to the graph."""
        # Imported inside the method rather than at module level;
        # presumably to avoid a circular import - TODO confirm.
        from .shape_info import get_shape_info

        shape_info = get_shape_info(self._trt, profile)
        self.assign_shapes(shape_info)

    def as_trt(self) -> trt.INetworkDefinition:
        """Return the wrapped TensorRT network."""
        return self._trt

    def get_input(self, name: str) -> Tensor:
        """Return the graph input called ``name`` or ``None``."""
        return self._inputs.get(name)

    def is_input(self, name: str) -> bool:
        return name in self._inputs

    @property
    def inputs(self) -> List[Tensor]:
        """Snapshot list of the graph inputs."""
        return [*self._inputs.values()]

    @property
    def num_inputs(self) -> int:
        """Input count as reported by the wrapped network."""
        return self._trt.num_inputs

    def get_output(self, name: str) -> Tensor:
        """Return the graph output called ``name`` or ``None``."""
        return self._outputs.get(name)

    def is_output(self, name: str) -> bool:
        return name in self._outputs

    @property
    def outputs(self) -> List[Tensor]:
        """Snapshot list of the graph outputs."""
        return [*self._outputs.values()]

    @property
    def num_outputs(self) -> int:
        """Output count as reported by the wrapped network."""
        return self._trt.num_outputs

    def get_tensor(self, name: str) -> Tensor:
        """Return the tensor called ``name`` or ``None``."""
        return self._tensors.get(name)

    @property
    def tensors(self) -> List[Tensor]:
        """Snapshot list of all registered tensors."""
        return [*self._tensors.values()]

    def get_layer(self, name: str) -> Layer:
        """Return the layer called ``name`` or ``None``."""
        return self._layers.get(name)

    @property
    def layers(self) -> List[Layer]:
        """Snapshot list of all registered layers."""
        return [*self._layers.values()]

    @property
    def sorted_layers(self) -> List[Layer]:
        """Wrapped layers in the order given by ``get_sorted_layer_ids``."""
        network = self.as_trt()
        return [
            self.get_layer(network.get_layer(layer_id).name)
            for layer_id in get_sorted_layer_ids(network)
        ]

    @property
    def num_layers(self) -> int:
        """Layer count as reported by the wrapped network."""
        return self._trt.num_layers

    def to_dot(self,
               path=None,
               per_device=False,
               per_block=False,
               ignore_shape_io=False,
               no_style=False,
               extra_attrs=None) -> Optional[str]:
        '''
        Get a graphviz representation of the graph.

        Tensors with a single-output producer are not drawn as nodes of
        their own; their label is merged into the producing layer's node.

        Parameters:
            path: the path to save the graphviz file, if not provided, will return the graphviz source code
            per_device: group nodes whose name starts with "device" into one cluster per name prefix (text before the first "_")
            per_block: group layers that carry a "block_id" entry in attrs into one cluster per block
            ignore_shape_io: skip layers flagged as shape IO
            no_style: emit nodes and clusters without styling attributes
            extra_attrs: names of attrs entries to append to tensor and layer labels

        Returns:
            The graphviz source when path is not provided; None when the
            file was saved or when graphviz cannot be imported.
        '''
        try:
            import graphviz
        except ImportError:
            logger.error(
                "Failed to import graphviz, please install graphviz to enable PipelineGraph.to_dot()"
            )
            return None

        extra_attrs = extra_attrs or []

        graph = graphviz.Digraph()
        input_block_graph = graphviz.Digraph(name='cluster_inputs')
        output_block_graph = graphviz.Digraph(name='cluster_outputs')
        device_graphs = {}
        block_graphs = {}
        # (parent graph, block graph) pairs, attached at the very end.
        block_graph_mapping = []
        # Names of the nodes actually emitted; edges are only drawn
        # between emitted nodes.
        tensor_names = set()
        layer_names = set()

        common_style = dict(fontname='Arial', )
        node_style = dict(
            **common_style,
            style='rounded,filled,bold',
        )
        tensor_style = dict(
            **node_style,
            shape='ellipse',
            fillcolor='white',
        )
        input_tensor_style = {**tensor_style, 'fillcolor': 'green'}
        output_tensor_style = {**tensor_style, 'fillcolor': 'lightgreen'}
        layer_style = dict(
            **node_style,
            shape='box',
            fillcolor='white',
        )
        shape_layer_style = {**layer_style, 'fillcolor': 'grey'}
        helper_layer_style = {**layer_style, 'fillcolor': 'lightgrey'}
        graph_style = dict(
            **common_style,
            style='rounded',
            penwidth='5',
            fontsize='28',
        )
        device_graph_style = dict(
            **graph_style,
            color='cornflowerblue',
        )
        block_graph_style = dict(
            **graph_style,
            color='darkcyan',
        )
        input_block_style = dict(
            **graph_style,
            color='green',
        )
        output_block_style = dict(
            **graph_style,
            color='lightgreen',
        )
        if no_style:
            device_graph_style = {}
            block_graph_style = {}
            input_block_style = {}
            output_block_style = {}
        input_block_graph.attr(label='inputs', **input_block_style)
        output_block_graph.attr(label='outputs', **output_block_style)

        def get_tensor_labels(tensor):
            # Label lines for a tensor: its value if known, otherwise
            # dtype and shape, followed by the requested extra attrs.
            labels = []
            if tensor.value is not None:
                labels.append(f"value={tensor.value}")
            else:
                labels.append(f"dtype={tensor.dtype.name}{tensor.shape}")
            for attr_name in extra_attrs:
                if attr_name in tensor.attrs:
                    labels.append(f"{attr_name}={tensor.attrs[attr_name]}")
            return labels

        def get_device_graph(name):
            # Cluster for the device prefix of ``name`` (created lazily),
            # or None when per-device grouping does not apply.
            if per_device and name.startswith('device'):
                device_name = name.split('_')[0]
                if device_name not in device_graphs:
                    device_graph = graphviz.Digraph(name='cluster_' +
                                                    device_name)
                    device_graph.attr(label=device_name, **device_graph_style)
                    device_graphs[device_name] = device_graph
                return device_graphs[device_name]
            return None

        def get_block_graph(layer, current_graph):
            # Cluster for the layer's block nested in ``current_graph``
            # (created lazily), or ``current_graph`` itself when per-block
            # grouping does not apply.
            if per_block and 'block_id' in layer.attrs:
                block_label = f"block{layer.attrs['block_id']}"
                if current_graph.name is not None:
                    graph_label = current_graph.name[len('cluster_'):]
                else:
                    graph_label = ''
                block_name = f"{graph_label}{block_label}"
                if block_name not in block_graphs:
                    block_graph = graphviz.Digraph(name='cluster_' +
                                                   block_name)
                    block_graph.attr(label=block_label, **block_graph_style)
                    block_graphs[block_name] = block_graph
                    block_graph_mapping.append((current_graph, block_graph))
                return block_graphs[block_name]
            return current_graph

        for name, tensor in self._tensors.items():
            style = tensor_style
            if tensor.is_graph_input:
                style = input_tensor_style
                current_graph = input_block_graph
            elif tensor.is_graph_output:
                style = output_tensor_style
                current_graph = output_block_graph
            elif tensor.producer.num_outputs == 1:
                # Folded into the producing layer's node below.
                continue
            else:
                current_graph = get_device_graph(name) or graph
                current_graph = get_block_graph(tensor.producer,
                                                current_graph)
            if no_style:
                style = {}
            labels = [name, *get_tensor_labels(tensor)]
            content = "\n".join(labels)
            current_graph.node(name, content, **style)
            tensor_names.add(name)

        for layer in self.sorted_layers:
            name = layer.name

            style = layer_style
            if layer.is_shape_io:
                if ignore_shape_io:
                    continue
                style = shape_layer_style
            elif layer.attrs.get("role", None) == "helper":
                style = helper_layer_style
            fillcolor = None
            plugin_type = None
            if layer.type == trt.LayerType.PLUGIN_V2:
                fillcolor = 'yellow'
                layer.to_subclass()
                plugin_type = layer.as_trt().plugin.plugin_type
                layer.to_base_class()
            if (layer.type == trt.LayerType.MATRIX_MULTIPLY
                    or plugin_type == 'Gemm'):
                fillcolor = 'orange'
            if fillcolor is not None:
                style = {**style, 'fillcolor': fillcolor}
            if no_style:
                style = {}

            # Collect type-specific attributes worth showing in the label.
            layer_attrs = {}
            layer_type = layer.type
            layer.to_subclass()
            if layer_type == trt.LayerType.CONSTANT:
                if not layer.is_shape_io:
                    # Only print the value of small constants.
                    if trt.volume(layer.get_output(0).shape) <= 8:
                        weights = layer.as_trt().weights
                        if isinstance(weights, trt.Weights):
                            weights = weights.numpy()
                        value = np.array2string(
                            weights,
                            formatter={'float_kind': lambda x: f"{x:.2e}"})
                        layer_attrs['value'] = value
            elif layer_type == trt.LayerType.SHUFFLE:
                for attr_name in ['first_transpose', 'second_transpose']:
                    attr_value = getattr(layer.as_trt(), attr_name)
                    # Presumably (0..7) is the default identity
                    # permutation - TODO confirm against TensorRT docs.
                    if tuple(attr_value) != (0, 1, 2, 3, 4, 5, 6, 7):
                        if attr_name == 'first_transpose':
                            tensor = layer.get_input(0)
                        else:
                            tensor = layer.get_output(0)
                        layer_attrs[attr_name] = tuple(
                            attr_value)[:len(tensor.shape)]
                if layer.num_inputs < 2:
                    attr_value = layer.as_trt().reshape_dims
                    layer_attrs['reshape_dims'] = attr_value
            elif layer_type == trt.LayerType.SLICE:
                if layer.num_inputs < 2 or layer.get_input(1) is None:
                    layer_attrs['start'] = layer.as_trt().start
                if layer.num_inputs < 4 or layer.get_input(3) is None:
                    attr_value = layer.as_trt().stride
                    if attr_value != tuple(
                            [1] * len(layer.get_output(0).shape)):
                        layer_attrs['stride'] = attr_value
            layer.to_base_class()

            if layer.is_shape_io:
                labels = [layer.type.name]
            else:
                labels = [name, layer.type.name]
            for key, value in layer_attrs.items():
                labels.append(f"{key}={value}")
            for attr_name in extra_attrs:
                if attr_name in layer.attrs:
                    labels.append(f"{attr_name}={layer.attrs[attr_name]}")
            if layer.num_outputs == 1:
                output = layer.get_output(0)
                if output.name != f'{layer.name}_output_0':
                    labels.append(f"output={output.name}")
                labels.extend(get_tensor_labels(output))
            content = "\n".join(labels)

            current_graph = get_device_graph(name) or graph
            current_graph = get_block_graph(layer, current_graph)
            current_graph.node(name, content, **style)
            layer_names.add(name)

            for index, input_tensor in enumerate(layer.inputs):
                if input_tensor is None:
                    continue
                if (input_tensor.is_graph_input
                        or input_tensor.producer.num_outputs > 1):
                    # The tensor has its own node: edge from the tensor.
                    if input_tensor.name in tensor_names:
                        graph.edge(input_tensor.name, name, str(index))
                else:
                    # Tensor folded into its producer: edge from the layer.
                    if input_tensor.producer.name in layer_names:
                        graph.edge(input_tensor.producer.name, name,
                                   str(index))
            if layer.num_outputs > 1 or (layer.num_outputs == 1 and
                                         layer.get_output(0).is_graph_output):
                for index, output in enumerate(layer.outputs):
                    graph.edge(name, output.name, str(index))

        graph.subgraph(input_block_graph)
        graph.subgraph(output_block_graph)
        for parent_graph, block_graph in block_graph_mapping:
            parent_graph.subgraph(block_graph)
        for device_graph in device_graphs.values():
            graph.subgraph(device_graph)

        if not path:
            return graph.source
        graph.save(path)
        return None

    @staticmethod
    def trt_to_dot(trt_network, path=None):
        """Render a raw TensorRT network as unstyled graphviz source.

        Returns the source when ``path`` is ``None``; otherwise writes it
        to ``path`` and returns ``None``. Also returns ``None`` (without
        writing) when ``to_dot`` produced nothing, which happens when
        graphviz cannot be imported.
        """
        graph = PipelineGraph.from_trt(trt_network)
        graph.assign_shapes()
        dot = graph.to_dot(no_style=True)
        if path is None or dot is None:
            return dot
        with open(path, "w") as f:
            f.write(dot)
        return None
| |
|