from collections import OrderedDict
from typing import List

import tensorrt as trt

from ..._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import (Tensor, arange, concat, expand,
                           gather_last_token_logits, shape, tanh, unsqueeze)
from ...layers import (Attention, AttentionMaskType, AttentionParams,
                       ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                       PositionEmbeddingType, Recurrent, RmsNorm)
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..generation_mixin import GenerationMixin
from ..modeling_utils import PretrainedConfig, PretrainedModel


class ResidualLayer(Module):
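    """A RecurrentGemma residual block.

    Applies an input RmsNorm followed by a temporal-mixing block, which is
    either a recurrent block or causal self-attention depending on
    ``config.layer_types``, then a post RmsNorm and a gated MLP, with
    residual connections around both sub-blocks.
    """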

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__()
        layer_type_len = len(config.layer_types)
        self.temporal_block_type = config.layer_types[layer_idx %
                                                      layer_type_len]

        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                       eps=config.norm_epsilon,
                                       dtype=config.dtype)

        if self.temporal_block_type == 'recurrent':
            self.recurrent = Recurrent(width=config.hidden_size,
                                       lru_width=config.rnn_hidden_size,
                                       d_conv=config.conv_kernel,
                                       num_heads=config.num_attention_heads,
                                       dtype=config.dtype,
                                       tp_group=config.mapping.tp_group,
                                       tp_size=config.mapping.tp_size)
        elif self.temporal_block_type == 'attention':
            # local_layer_idx is the 0-based index of this layer among the
            # attention layers only (recurrent layers are skipped).
            layer_types = config.layer_types * (
                (layer_idx + 1) // layer_type_len)
            layer_types = layer_types + config.layer_types[0:(
                (layer_idx + 1) % layer_type_len)]
            attention_layer_idx = layer_types.count('attention') - 1
            self.attention = Attention(
                local_layer_idx=attention_layer_idx,
                hidden_size=config.hidden_size,
                num_attention_heads=config.num_attention_heads,
                num_kv_heads=config.num_key_value_heads,
                dtype=config.dtype,
                attention_mask_type=AttentionMaskType.causal,
                position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
                rotary_embedding_percentage=config.rotary_pct,
                tp_group=config.mapping.tp_group,
                tp_size=config.mapping.tp_size,
                tp_rank=config.mapping.tp_rank,
                quant_mode=config.quant_mode,
                bias=False,
                dense_bias=True)
        else:
            raise ValueError(
                'RecurrentGemma only supports "recurrent" and "attention" blocks.'
            )

        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                      eps=config.norm_epsilon,
                                      dtype=config.dtype)

        self.mlp = GatedMLP(hidden_size=config.hidden_size,
                            ffn_hidden_size=config.intermediate_size,
                            hidden_act=config.hidden_act,
                            dtype=config.dtype,
                            tp_group=config.mapping.tp_group,
                            tp_size=config.mapping.tp_size,
                            quant_mode=config.quant_mode)
    def forward(self,
                hidden_states,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                conv_state=None,
                lru_state=None,
                host_request_types=None,
                last_token_ids=None,
                host_context_lengths=None,
                slot_mapping=None,
                conv_indices=None):

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        if self.temporal_block_type == 'recurrent':
            temporal_output, present_conv, present_lru = self.recurrent(
                hidden_states,
                conv_state=conv_state,
                lru_state=lru_state,
                host_request_types=host_request_types,
                last_token_ids=last_token_ids,
                host_context_lengths=host_context_lengths,
                slot_mapping=slot_mapping,
                conv_indices=conv_indices,
            )
        else:
            present_conv, present_lru = None, None

        if self.temporal_block_type == 'attention':
            temporal_output, present_kv = self.attention(
                hidden_states,
                attention_mask=attention_mask,
                use_cache=use_cache,
                kv_cache_params=kv_cache_params,
                attention_params=attention_params)
        else:
            present_kv = None

        hidden_states = residual + temporal_output

        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, present_kv, present_conv, present_lru


class RecurrentGemmaModel(Module):
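    """The RecurrentGemma backbone: token embedding, the stack of
    ResidualLayer blocks, and a final RmsNorm."""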

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()
        self.d_conv = config.conv_kernel
        self.lru_width = config.rnn_hidden_size
        n_layer = config.num_hidden_layers

        self.vocab_embedding = Embedding(config.vocab_size,
                                         config.hidden_size,
                                         dtype=config.dtype)
        self.layers = ModuleList(
            [ResidualLayer(config, layer_idx=i) for i in range(n_layer)])

        self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                            eps=config.norm_epsilon,
                            dtype=config.dtype)
    def forward(self,
                input_ids,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                conv_states=None,
                lru_states=None,
                host_request_types=None,
                last_token_ids=None,
                host_context_lengths=None,
                slot_mapping=None):

        hidden_states = self.vocab_embedding(input_ids)

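        # Without the mamba_conv1d plugin, precompute per-sequence gather
        # indices (offset by last_token_ids) that the recurrent layers use to
        # update their conv states.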
        indices = None
        if not default_net().plugin_config.mamba_conv1d_plugin:
            batch_size = shape(input_ids, 0)
            indices = expand(
                unsqueeze(arange(0, self.d_conv - 1, dtype='int32'), 0),
                concat([batch_size, self.d_conv - 1]))
            offsets = expand(unsqueeze(last_token_ids, 1),
                             concat([batch_size, self.d_conv - 1]))
            indices = unsqueeze(indices + offsets, 1)
            indices = expand(
                indices, concat([batch_size, self.lru_width, self.d_conv - 1]))

        present_kvs, present_convs, present_lrus = [], [], []
        for layer, past_kv, past_conv, past_lru in zip(
                self.layers, kv_cache_params.past_key_value, conv_states,
                lru_states):
            hidden_states, present_kv, present_conv, present_lru = layer(
                hidden_states,
                use_cache,
                attention_mask,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past_kv],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_attention_window_sizes=kv_cache_params.
                    host_max_attention_window_sizes,
                    host_sink_token_length=kv_cache_params.
                    host_sink_token_length,
                    kv_cache_block_offsets=kv_cache_params.
                    kv_cache_block_offsets,
                    host_kv_cache_block_offsets=kv_cache_params.
                    host_kv_cache_block_offsets,
                    host_kv_cache_pool_pointers=kv_cache_params.
                    host_kv_cache_pool_pointers,
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                conv_state=past_conv,
                lru_state=past_lru,
                host_request_types=host_request_types,
                last_token_ids=last_token_ids,
                host_context_lengths=host_context_lengths,
                slot_mapping=slot_mapping,
                conv_indices=indices)
            present_kvs.append(present_kv)
            present_convs.append(present_conv)
            present_lrus.append(present_lru)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states, tuple(present_kvs), tuple(present_convs), tuple(
            present_lrus)


class RecurrentGemmaForCausalLM(PretrainedModel):
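    """RecurrentGemma with a language-modeling head and logit soft-capping,
    plus the input preparation used when building the TRT engine."""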

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        dtype = config.dtype
        logits_dtype = config.logits_dtype
        if isinstance(dtype, str):
            self.dtype = str_dtype_to_trt(dtype)
        else:
            assert isinstance(dtype, trt.DataType)
            self.dtype = dtype

        # Expand the repeating layer-type pattern so there is one entry per
        # hidden layer.
        assert len(config.layer_types) > 0
        layer_types = config.layer_types
        layer_types = layer_types * (config.num_hidden_layers //
                                     len(layer_types))
        layer_types = layer_types + layer_types[0:(config.num_hidden_layers %
                                                   len(layer_types))]
        self.layer_types = layer_types

        self.config = config
        self.gather_context_logits = False
        self.logits_soft_cap = config.logits_soft_cap

        if isinstance(logits_dtype, str):
            self._logits_dtype = str_dtype_to_trt(logits_dtype)
        else:
            assert isinstance(logits_dtype, trt.DataType)
            self._logits_dtype = logits_dtype

        self.transformer = RecurrentGemmaModel(config)
        self.lm_head = ColumnLinear(config.hidden_size,
                                    config.vocab_size,
                                    bias=False,
                                    dtype=dtype,
                                    tp_group=config.mapping.tp_group,
                                    tp_size=config.mapping.tp_size,
                                    gather_output=True)

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                conv_states=None,
                rnn_states=None,
                host_request_types=None,
                last_token_ids=None,
                last_token_ids_for_logits=None,
                host_context_lengths=None,
                slot_mapping=None):
        hidden_states, present_kvs, present_convs, present_rnns = self.transformer(
            input_ids, use_cache, attention_mask, kv_cache_params,
            attention_params, conv_states, rnn_states, host_request_types,
            last_token_ids, host_context_lengths, slot_mapping)

        if not self.gather_context_logits:
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids_for_logits,
                default_net().plugin_config.remove_input_padding)

        lm_logits = self.lm_head(hidden_states)
        # Soft-cap the logits to the range [-logits_soft_cap, logits_soft_cap].
        lm_logits = tanh(
            lm_logits / self.logits_soft_cap) * self.logits_soft_cap
        lm_logits.mark_output('logits', self._logits_dtype)
        if not default_net().plugin_config.paged_kv_cache:
            for i, present_kv in enumerate(present_kvs):
                if present_kv is not None:
                    present_kv.mark_output(f'present_key_value_{i}',
                                           self.dtype)

        if not default_net().plugin_config.paged_state:
            for i, present_conv in enumerate(present_convs):
                if present_conv is not None:
                    present_conv.mark_output(f'present_conv_state_{i}',
                                             self.dtype)
            for i, present_rnn in enumerate(present_rnns):
                if present_rnn is not None:
                    present_rnn.mark_output(f'present_rnn_state_{i}',
                                            str_dtype_to_trt('float32'))

        return (lm_logits, present_kvs, present_convs, present_rnns)

    def prepare_recurrent_inputs(self, max_batch_size, num_profiles, mapping):
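        '''@brief: Create the per-layer conv/RNN state input Tensors (or state
            pointer Tensors when the paged_state plugin is enabled), plus the
            slot_mapping Tensor.

        @return: a dict with 'conv_states', 'rnn_states' and 'slot_mapping'
        '''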
        use_mamba_conv1d_plugin = default_net(
        ).plugin_config.mamba_conv1d_plugin

        default_range = GenerationMixin.default_range
        batch_range = [default_range(max_batch_size)] * num_profiles

        conv_states = []
        rnn_states = []
        dim = self.config.rnn_hidden_size // mapping.tp_size
        if use_mamba_conv1d_plugin:
            conv_state_dim_range = OrderedDict([
                ('batch_size', batch_range),
                ('kernel_size', [self.config.conv_kernel - 1] * num_profiles),
                ('dim_size', [dim] * num_profiles),
            ])
        else:
            conv_state_dim_range = OrderedDict([
                ('batch_size', batch_range),
                ('dim_size', [dim] * num_profiles),
                ('kernel_size', [self.config.conv_kernel - 1] * num_profiles),
            ])

        rnn_state_dim_range = OrderedDict([
            ('batch_size', batch_range),
            ('state_size', [1] * num_profiles),
            ('dim_size', [dim] * num_profiles),
        ])
        one_dim_range = OrderedDict([
            ('buffer_count', [1] * num_profiles),
        ])

        for i in range(self.config.num_hidden_layers):
            if self.layer_types[i] == 'recurrent':
                if default_net().plugin_config.paged_state:
                    conv_state = Tensor(name=f'conv_state_ptr_{i}',
                                        dtype=str_dtype_to_trt('int64'),
                                        shape=[1],
                                        dim_range=one_dim_range)

                    rnn_state = Tensor(name=f'rnn_state_ptr_{i}',
                                       dtype=str_dtype_to_trt('int64'),
                                       shape=[1],
                                       dim_range=one_dim_range)
                else:
                    if use_mamba_conv1d_plugin:
                        conv_state = Tensor(
                            name=f'past_conv_state_{i}',
                            dtype=self.dtype,
                            shape=[-1, self.config.conv_kernel - 1, dim],
                            dim_range=conv_state_dim_range)
                    else:
                        conv_state = Tensor(
                            name=f'past_conv_state_{i}',
                            dtype=self.dtype,
                            shape=[-1, dim, self.config.conv_kernel - 1],
                            dim_range=conv_state_dim_range)

                    rnn_state = Tensor(name=f'past_rnn_state_{i}',
                                       dtype=str_dtype_to_trt('float32'),
                                       shape=[-1, 1, dim],
                                       dim_range=rnn_state_dim_range)
            else:
                conv_state, rnn_state = None, None
            conv_states.append(conv_state)
            rnn_states.append(rnn_state)

        slot_mapping = None
        if default_net().plugin_config.paged_state:
            slot_mapping = Tensor(
                name='slot_mapping',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size', batch_range)]),
            )

        return_dict = {
            'conv_states': conv_states,
            'rnn_states': rnn_states,
            'slot_mapping': slot_mapping,
        }
        return return_dict

    def prepare_inputs(
            self,
            max_batch_size,
            max_input_len,
            max_seq_len,
            max_num_tokens,
            use_cache,
            max_beam_width: int = 1,
            opt_num_tokens: int = None,
            opt_batch_size: int = 0,
            prompt_embedding_table_size: int = 0,
            max_draft_len: int = 0,
            gather_context_logits: bool = False,
            gather_generation_logits: bool = False,
            lora_target_modules: List[str] = None,
            speculative_decoding_draft_tokens_external: bool = False):
        '''@brief: Prepare input Tensors for the model. The given sizes are
            used to determine the ranges of the dynamic dimensions when the
            engine is built with TRT dynamic shapes.

        @return: a dict whose values can be fed into self.forward()
        '''
        assert not speculative_decoding_draft_tokens_external, \
            "We don't support speculative decoding for the RecurrentGemma model."
        assert max_beam_width == 1, "We don't support beam search for the RecurrentGemma model."

        remove_input_padding = default_net().plugin_config.remove_input_padding
        use_gpt_attention_plugin = default_net(
        ).plugin_config.gpt_attention_plugin
        use_gemm_plugin = default_net().plugin_config.gemm_plugin
        paged_kv_cache = default_net().plugin_config.paged_kv_cache
        tokens_per_block = default_net().plugin_config.tokens_per_block
        multiple_profiles = default_net().plugin_config.multiple_profiles
        streamingllm = default_net().plugin_config.streamingllm
        use_mamba_conv1d_plugin = default_net(
        ).plugin_config.mamba_conv1d_plugin

        self.gather_context_logits = gather_context_logits
        mapping = self.config.mapping

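        # Set up the optimization profiles and the ranges for the dynamic
        # input dimensions.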
        enable_ctx_gen_opt_profiles = GenerationMixin.has_ctx_gen_opt_profiles(
            use_gpt_attention_plugin, use_gemm_plugin, remove_input_padding,
            paged_kv_cache)
        num_profiles, ranges = GenerationMixin.get_profiles_ranges(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_num_tokens=max_num_tokens,
            max_draft_len=max_draft_len,
            opt_batch_size=opt_batch_size,
            opt_num_tokens=opt_num_tokens,
            enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles,
            multiple_profiles=multiple_profiles)

        if remove_input_padding:
            assert use_mamba_conv1d_plugin, "mamba_conv1d_plugin is needed to support remove_input_padding"
            input_ids = Tensor(name='input_ids',
                               dtype=trt.int32,
                               shape=[-1],
                               dim_range=OrderedDict([
                                   ('num_tokens', ranges['num_tokens_range']),
                               ]))
            position_ids = Tensor(name='position_ids',
                                  dtype=trt.int32,
                                  shape=[-1],
                                  dim_range=OrderedDict([
                                      ('position_ids_num_tokens_range',
                                       ranges['num_tokens_range']),
                                  ]))
        else:
            input_ids = Tensor(name='input_ids',
                               dtype=trt.int32,
                               shape=[-1, -1],
                               dim_range=OrderedDict([
                                   ('batch_size_beam_width',
                                    ranges['bb_range']),
                                   ('input_len', ranges['inlen_range']),
                               ]))
            position_ids = Tensor(name='position_ids',
                                  dtype=trt.int32,
                                  shape=[-1, -1],
                                  dim_range=OrderedDict([
                                      ('batch_size_beam_width',
                                       ranges['bb_range']),
                                      ('position_ids_inlen_range',
                                       ranges['position_ids_inlen_range']),
                                  ]))
        if mapping.tp_size > 1:
            current_all_reduce_helper().set_workspace_tensor(
                mapping, num_profiles)

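        # Attention inputs: only the attention layers own a KV cache, so
        # collect their layer indices and prepare attention tensors for them.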
        num_attention_layers = self.layer_types.count('attention')
        attn_layer_idx = []
        for i in range(self.config.num_hidden_layers):
            if self.layer_types[i] == 'attention':
                attn_layer_idx.append(i)
        attention_inputs = self.prepare_attention_inputs(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_seq_len=max_seq_len,
            num_kv_heads=self.config.num_key_value_heads,
            head_size=self.config.head_size,
            num_layers=num_attention_layers,
            kv_dtype=str_dtype_to_trt(self.config.kv_dtype),
            num_profiles=num_profiles,
            enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles,
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            paged_kv_cache=paged_kv_cache,
            tokens_per_block=tokens_per_block,
            mapping=mapping,
            streamingllm=streamingllm,
            attn_layer_idx=attn_layer_idx)

        # Scatter the prepared KV cache tensors back to per-layer positions;
        # recurrent layers (and all layers when the KV cache is paged) get None.
        kv_idx = 0
        past_key_value = []
        for i in range(self.config.num_hidden_layers):
            if self.layer_types[i] == 'attention' and not paged_kv_cache:
                past_key_value.append(
                    attention_inputs['past_key_value'][kv_idx])
                kv_idx += 1
            else:
                past_key_value.append(None)
        attention_inputs['past_key_value'] = past_key_value

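        # Recurrent inputs: per-layer conv/RNN states (or state pointers) and
        # the slot_mapping tensor.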
        recurrent_inputs = self.prepare_recurrent_inputs(
            max_batch_size=max_batch_size,
            num_profiles=num_profiles,
            mapping=mapping,
        )

        if use_gpt_attention_plugin:
            host_request_types = attention_inputs['host_request_types']
        else:
            host_request_types = Tensor(
                name='host_request_types',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width',
                                        ranges['bb_range'])]),
            )

        last_token_ids = Tensor(
            name='last_token_ids',
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([
                ('batch_size_last_token_ids', ranges['bbd_range']),
            ]),
        )
        last_token_ids_for_logits = None
        if not gather_context_logits:
            last_token_ids_for_logits = last_token_ids

        if use_gpt_attention_plugin and remove_input_padding:
            host_context_lengths = attention_inputs['host_context_lengths']
        elif remove_input_padding:
            host_context_lengths = Tensor(
                name='host_context_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width',
                                        ranges['bb_range'])]),
            )
        else:
            host_context_lengths = None

        return_dict = {
            'input_ids':
            input_ids,
            'position_ids':
            position_ids,
            'use_cache':
            True,
            'attention_mask':
            attention_inputs['attention_mask'],
            'kv_cache_params':
            KeyValueCacheParams(
                past_key_value=attention_inputs['past_key_value'],
                host_past_key_value_lengths=attention_inputs[
                    'host_past_key_value_lengths'],
                host_max_attention_window_sizes=attention_inputs[
                    'host_max_attention_window_sizes'],
                host_sink_token_length=attention_inputs[
                    'host_sink_token_length'],
                kv_cache_block_offsets=attention_inputs[
                    'kv_cache_block_offsets'],
                host_kv_cache_block_offsets=attention_inputs[
                    'host_kv_cache_block_offsets'],
                host_kv_cache_pool_pointers=attention_inputs[
                    'host_kv_cache_pool_pointers'],
                cache_indirection=attention_inputs['cache_indirection'],
            ),
            'attention_params':
            AttentionParams(
                sequence_length=attention_inputs['sequence_length'],
                context_lengths=attention_inputs['context_lengths'],
                host_context_lengths=attention_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=attention_inputs['host_request_types'],
                host_runtime_perf_knobs=attention_inputs[
                    'host_runtime_perf_knobs']),
            'conv_states':
            recurrent_inputs['conv_states'],
            'rnn_states':
            recurrent_inputs['rnn_states'],
            'host_request_types':
            host_request_types,
            'last_token_ids':
            last_token_ids,
            'last_token_ids_for_logits':
            last_token_ids_for_logits,
            'host_context_lengths':
            host_context_lengths,
            'slot_mapping':
            recurrent_inputs['slot_mapping'],
        }
        return return_dict