import copy
from pathlib import Path
from typing import List, Optional, Union

import torch

import tensorrt_llm.bindings.executor as trtllm

from .. import profiler
from ..bindings import DataType, GptJsonConfig, ModelConfig, WorldConfig
from ..logger import logger
from ..mapping import Mapping
from .generation import LogitsProcessor, SamplingConfig, StoppingCriteria
from .model_runner import ModelRunnerMixin

_bindings_dtype_to_torch_dtype_dict = {
    DataType.FLOAT: torch.float,
    DataType.HALF: torch.half,
    DataType.INT8: torch.int8,
    DataType.INT32: torch.int32,
    DataType.BOOL: torch.bool,
    DataType.UINT8: torch.uint8,
    DataType.BF16: torch.bfloat16,
    DataType.INT64: torch.int64
}


class ModelRunnerCpp(ModelRunnerMixin):
    """
    An interface class that wraps Executor and provides generation methods.
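
    A minimal usage sketch (illustrative only; the engine path, token ids,
    end_id and pad_id below are placeholders that normally come from the
    caller's tokenizer):

        runner = ModelRunnerCpp.from_dir("/path/to/engine_dir")
        batch_input_ids = [torch.tensor([1, 2, 3], dtype=torch.int32)]
        output_ids = runner.generate(batch_input_ids,
                                     max_new_tokens=8,
                                     end_id=2,
                                     pad_id=2)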
    """

    def __init__(self, executor: trtllm.Executor, max_batch_size: int,
                 max_input_len: int, max_seq_len: int, max_beam_width: int,
                 model_config: ModelConfig, world_config: WorldConfig) -> None:
        self.session = executor
        self.max_batch_size = max_batch_size
        self.max_input_len = max_input_len
        self.max_seq_len = max_seq_len
        self.max_beam_width = max_beam_width
        self.model_config = model_config
        self.mapping = Mapping(world_size=world_config.tensor_parallelism *
                               world_config.pipeline_parallelism,
                               rank=world_config.rank,
                               gpus_per_node=world_config.gpus_per_node,
                               tp_size=world_config.tensor_parallelism,
                               pp_size=world_config.pipeline_parallelism)
        self.world_config = world_config

    @classmethod
    def from_dir(cls,
                 engine_dir: str,
                 *,
                 lora_dir: Optional[str] = None,
                 rank: int = 0,
                 max_batch_size: Optional[int] = None,
                 max_input_len: Optional[int] = None,
                 max_output_len: Optional[int] = None,
                 max_beam_width: Optional[int] = None,
                 max_attention_window_size: Optional[int] = None,
                 sink_token_length: Optional[int] = None,
                 kv_cache_free_gpu_memory_fraction: Optional[float] = None,
                 medusa_choices: list[list[int]] | None = None,
                 debug_mode: bool = False,
                 lora_ckpt_source: str = "hf",
                 gpu_weights_percent: float = 1,
                 max_tokens_in_paged_kv_cache: int | None = None,
                 kv_cache_enable_block_reuse: bool = False,
                 enable_chunked_context: bool = False,
                 is_enc_dec: bool = False,
                 multi_block_mode: Optional[bool] = None) -> 'ModelRunnerCpp':
        """
        Create a ModelRunnerCpp instance from an engine directory.

        Args:
            engine_dir (str):
                The directory that contains the serialized engine files and config files.
            lora_dir (str):
                The directory that contains LoRA weights.
            rank (int):
                The runtime rank id.
            max_batch_size (int):
                The runtime batch size limit. If max_batch_size is not None, it should not
                be larger than the engine's max_batch_size; otherwise, the engine's max_batch_size
                will be used.
            max_input_len (int):
                The runtime input length limit. If max_input_len is not None, it should not
                be larger than the engine's max_input_len; otherwise, the engine's max_input_len
                will be used.
            max_output_len (int):
                The runtime output length limit. If max_output_len is not None, it should not
                be larger than the engine's max_output_len; otherwise, the engine's max_output_len
                will be used.
            max_beam_width (int):
                The runtime beam width limit. If max_beam_width is not None, it should not
                be larger than the engine's max_beam_width; otherwise, the engine's max_beam_width
                will be used.
            max_attention_window_size (int):
                The attention window size that controls the sliding window attention / cyclic kv cache behavior.
            sink_token_length (int):
                The sink token length, default=0.
            kv_cache_free_gpu_memory_fraction (float):
                The fraction of free GPU memory that the KV cache is allowed to use.
            medusa_choices (List[List[int]]):
                Medusa choices to use when in Medusa decoding.
            debug_mode (bool):
                Whether or not to turn on the debug mode.
            lora_ckpt_source (str):
                Source of the LoRA checkpoint. Should be one of ['hf', 'nemo'].
            gpu_weights_percent (float):
                The percentage of engine weights that reside on the GPU (used for weight streaming).
            max_tokens_in_paged_kv_cache (int):
                Maximum number of tokens in the paged KV cache.
            kv_cache_enable_block_reuse (bool):
                Enables block reuse in the KV cache.
            enable_chunked_context (bool):
                Enables chunked context.
            is_enc_dec (bool):
                Whether the model is an encoder-decoder architecture.
            multi_block_mode (bool):
                Whether to distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel.
        Returns:
            ModelRunnerCpp: An instance of ModelRunnerCpp.
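
        Example (illustrative sketch; the engine directory and the memory
        fraction below are placeholder values, not defaults of this API):

            runner = ModelRunnerCpp.from_dir(
                "/path/to/engine_dir",
                rank=0,
                max_batch_size=8,
                kv_cache_free_gpu_memory_fraction=0.9)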
        """

        if is_enc_dec:
            encoder_config_path = Path(engine_dir) / "encoder" / "config.json"
            encoder_json_config = GptJsonConfig.parse_file(encoder_config_path)
            decoder_config_path = Path(engine_dir) / "decoder" / "config.json"
            decoder_json_config = GptJsonConfig.parse_file(decoder_config_path)
            decoder_model_config = decoder_json_config.model_config

            tp_size = decoder_json_config.tensor_parallelism
            pp_size = decoder_json_config.pipeline_parallelism
            gpus_per_node = decoder_json_config.gpus_per_node
            world_config = WorldConfig.mpi(tensor_parallelism=tp_size,
                                           pipeline_parallelism=pp_size,
                                           gpus_per_node=gpus_per_node)
            assert rank == world_config.rank

            profiler.start('load tensorrt_llm engine')

            # Split the free-memory budget between the encoder and decoder KV caches.
            if kv_cache_free_gpu_memory_fraction is not None:
                kv_cache_free_gpu_memory_fraction /= 2
            kv_cache_config = trtllm.KvCacheConfig(
                free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
                max_attention_window=max_attention_window_size,
                sink_token_length=sink_token_length)

            executor = trtllm.Executor(
                Path(engine_dir) / "encoder",
                Path(engine_dir) / "decoder", trtllm.ModelType.ENCODER_DECODER,
                trtllm.ExecutorConfig(max_beam_width=max_beam_width,
                                      kv_cache_config=kv_cache_config,
                                      gpu_weights_percent=gpu_weights_percent))

            profiler.stop('load tensorrt_llm engine')

            loading_time = profiler.elapsed_time_in_sec(
                "load tensorrt_llm engine")
            logger.info(f'Load engine takes: {loading_time} sec')

            # Note: the encoder-decoder path uses the runtime limits as given;
            # max_batch_size, max_input_len, max_output_len and max_beam_width
            # must therefore be provided explicitly by the caller.
            return cls(executor,
                       max_batch_size=max_batch_size,
                       max_input_len=max_input_len,
                       max_seq_len=max_input_len + max_output_len,
                       max_beam_width=max_beam_width,
                       model_config=decoder_model_config,
                       world_config=world_config)

        config_path = Path(engine_dir) / "config.json"
        json_config = GptJsonConfig.parse_file(config_path)
        model_config = json_config.model_config

        tp_size = json_config.tensor_parallelism
        pp_size = json_config.pipeline_parallelism
        gpus_per_node = json_config.gpus_per_node
        world_config = WorldConfig.mpi(tensor_parallelism=tp_size,
                                       pipeline_parallelism=pp_size,
                                       gpus_per_node=gpus_per_node)
        assert rank == world_config.rank

        profiler.start('load tensorrt_llm engine')

        kv_cache_config = trtllm.KvCacheConfig(
            free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
            max_attention_window=max_attention_window_size,
            sink_token_length=sink_token_length,
            max_tokens=max_tokens_in_paged_kv_cache,
            enable_block_reuse=kv_cache_enable_block_reuse)

        decoding_config = trtllm.DecodingConfig()
        if medusa_choices is not None:
            decoding_config.medusa_choices = medusa_choices
            # Medusa decoding does not support multi-block mode.
            if multi_block_mode is not None:
                multi_block_mode = False

        # Fall back to the engine's build-time limits when a runtime limit is not given.
        if max_batch_size is None:
            max_batch_size = model_config.max_batch_size
        else:
            assert max_batch_size <= model_config.max_batch_size
        if max_input_len is None:
            max_input_len = model_config.max_input_len

        if max_output_len is None:
            max_seq_len = model_config.max_seq_len
        else:
            max_seq_len = max_input_len + max_output_len
            assert max_seq_len <= model_config.max_seq_len
        if max_beam_width is None:
            max_beam_width = model_config.max_beam_width
        else:
            assert max_beam_width <= model_config.max_beam_width

        trtllm_config = trtllm.ExecutorConfig(
            max_beam_width=max_beam_width,
            kv_cache_config=kv_cache_config,
            decoding_config=decoding_config,
            gpu_weights_percent=gpu_weights_percent)
        trtllm_config.enable_chunked_context = enable_chunked_context
        if multi_block_mode is not None:
            trtllm_config.multi_block_mode = multi_block_mode
        executor = trtllm.Executor(engine_dir, trtllm.ModelType.DECODER_ONLY,
                                   trtllm_config)

        profiler.stop('load tensorrt_llm engine')

        loading_time = profiler.elapsed_time_in_sec("load tensorrt_llm engine")
        logger.info(f'Load engine takes: {loading_time} sec')

        return cls(executor,
                   max_batch_size=max_batch_size,
                   max_input_len=max_input_len,
                   max_seq_len=max_seq_len,
                   max_beam_width=max_beam_width,
                   model_config=model_config,
                   world_config=world_config)

    def _check_inputs(self, batch_input_ids: List[List[int]],
                      sampling_config: trtllm.SamplingConfig, max_new_tokens):
        batch_size = len(batch_input_ids)
        if batch_size > self.max_batch_size:
            raise RuntimeError(
                f"Input batch size ({batch_size}) exceeds the engine or specified limit ({self.max_batch_size})"
            )
        input_lengths = [len(x) for x in batch_input_ids]
        max_length = max(input_lengths)
        if max_length > self.max_input_len:
            raise RuntimeError(
                f"Maximum input length ({max_length}) exceeds the engine or specified limit ({self.max_input_len})"
            )
        if max_length + max_new_tokens > self.max_seq_len:
            raise RuntimeError(
                f"Maximum input length ({max_length}) + maximum new tokens ({max_new_tokens}) exceeds the engine or specified limit ({self.max_seq_len})"
            )
        if sampling_config.beam_width > self.max_beam_width:
            raise RuntimeError(
                f"Num beams ({sampling_config.beam_width}) exceeds the engine or specified limit ({self.max_beam_width})"
            )

    @property
    def dtype(self) -> torch.dtype:
        bindings_dtype = self.model_config.data_type
        return _bindings_dtype_to_torch_dtype_dict[bindings_dtype]

    @property
    def vocab_size(self) -> int:
        return self.model_config.vocab_size

    @property
    def vocab_size_padded(self) -> int:
        return self.model_config.vocab_size_padded(self.world_config.size)

    @property
    def hidden_size(self) -> int:
        return self.model_config.hidden_size

    @property
    def num_heads(self) -> int:
        return self.model_config.num_heads

    @property
    def num_layers(self) -> int:
        return self.model_config.num_layers(
            self.world_config.pipeline_parallelism)

    @property
    def max_sequence_length(self) -> int:
        return self.max_seq_len

    @property
    def remove_input_padding(self) -> bool:
        return self.model_config.use_packed_input

    @property
    def max_prompt_embedding_table_size(self) -> int:
        return self.model_config.max_prompt_embedding_table_size

    @property
    def gather_context_logits(self) -> bool:
        return self.model_config.compute_context_logits

    @property
    def gather_generation_logits(self) -> bool:
        return self.model_config.compute_generation_logits

    def generate(self,
                 batch_input_ids: List[torch.Tensor],
                 *,
                 encoder_input_ids: Optional[List[torch.Tensor]] = None,
                 sampling_config: Optional[SamplingConfig] = None,
                 lora_uids: Optional[list] = None,
                 streaming: bool = False,
                 stopping_criteria: Optional[StoppingCriteria] = None,
                 logits_processor: Optional[LogitsProcessor] = None,
                 max_new_tokens: int = 1,
                 end_id: int | None = None,
                 pad_id: int | None = None,
                 bad_words_list: list[list[int]] | None = None,
                 stop_words_list: list[list[int]] | None = None,
                 return_dict: bool = False,
                 output_sequence_lengths: bool = False,
                 output_log_probs: bool = False,
                 output_cum_log_probs: bool = False,
                 prompt_table: Optional[Union[str, torch.Tensor]] = None,
                 prompt_tasks: Optional[str] = None,
                 return_all_generated_tokens: bool = False,
                 **kwargs) -> Union[torch.Tensor, dict]:
        """
        Generates sequences of token ids.
        The generation-controlling parameters are set in the sampling_config; it will be set to a default one if not passed.
        You can override any sampling_config's attributes by passing corresponding parameters.

        Args:
            batch_input_ids (List[torch.Tensor]):
                A list of input id tensors. Each tensor is of shape (sequence_length, ).
            encoder_input_ids (List[torch.Tensor]):
                A list of encoder input id tensors, used only for encoder-decoder models.
            sampling_config (SamplingConfig):
                The sampling configuration to be used as base parametrization for the generation call.
                The passed **kwargs matching the sampling_config's attributes will override them.
                If the sampling_config is not provided, a default will be used.
            prompt_table (str or torch.Tensor):
                The file path of prompt table (.npy format, exported by nemo_prompt_convert.py) or the prompt table itself.
            prompt_tasks (str):
                The prompt tuning task ids for the input batch, in format of comma-separated list (e.g., 0,3,1,0).
            lora_uids (list):
                The uids of LoRA weights for the input batch. Use -1 to disable the LoRA module.
            streaming (bool):
                Whether or not to use streaming mode for generation.
            stopping_criteria (StoppingCriteria):
                Custom stopping criteria.
            logits_processor (LogitsProcessor):
                Custom logits processors.
            return_all_generated_tokens (bool):
                Whether the full output is returned at each streaming step.
            kwargs (Dict[str, Any]):
                Ad hoc parametrization of sampling_config.
                The passed **kwargs matching the sampling_config's attributes will override them.
        Returns:
            torch.Tensor or dict:
                If return_dict=False, the method returns generated output_ids.
                If return_dict=True, the method returns a dict of output_ids,
                sequence_lengths (if output_sequence_lengths=True),
                context_logits and generation_logits (if self.gather_context_logits=True and
                self.gather_generation_logits=True, respectively).
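
        Example (illustrative sketch; `input_ids`, `end_id` and `pad_id` are
        placeholders that normally come from the caller's tokenizer, and
        `runner` is assumed to have been created with ModelRunnerCpp.from_dir):

            outputs = runner.generate(batch_input_ids=input_ids,
                                      max_new_tokens=64,
                                      end_id=end_id,
                                      pad_id=pad_id,
                                      return_dict=True,
                                      output_sequence_lengths=True)
            output_ids = outputs['output_ids']         # token ids, padded with end_id
            seq_lengths = outputs['sequence_lengths']  # per-request, per-beam lengths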
        """

        if lora_uids is not None:
            raise RuntimeError("LoRA is not supported in C++ session.")
        if stopping_criteria is not None:
            raise RuntimeError(
                "Stopping criteria is not supported in C++ session.")
        if logits_processor is not None:
            raise RuntimeError(
                "Logits processor is not supported in C++ session.")

        # In multi-rank runs, only the rank that is allowed to enqueue requests
        # proceeds; the other ranks return an empty result.
        if not self.session.can_enqueue_requests():
            return []

        batch_input_ids_list = [a.tolist() for a in batch_input_ids]
        encoder_input_ids_list = [a.tolist() for a in encoder_input_ids
                                  ] if encoder_input_ids else None

        if sampling_config is None:
            # Build a trtllm.SamplingConfig from the accepted keyword arguments.
            accepted_parameters = [
                "num_beams", "top_k", "top_p", "top_p_min", "top_p_reset_ids",
                "top_p_decay", "random_seed", "temperature", "min_length",
                "beam_search_diversity_rate", "repetition_penalty",
                "presence_penalty", "frequency_penalty", "length_penalty",
                "early_stopping", "no_repeat_ngram_size"
            ]
            rename_params = {"num_beams": "beam_width"}
            sampling_params = {
                k: v
                for k, v in kwargs.items() if k in accepted_parameters
            }
            for k, v in rename_params.items():
                if k in sampling_params:
                    sampling_params[v] = sampling_params.pop(k)
            if "top_p" in sampling_params and sampling_params["top_p"] == 0.0:
                sampling_params["top_p"] = None

            sampling_config = trtllm.SamplingConfig(**sampling_params)
        else:
            sampling_config = copy.deepcopy(sampling_config)

        self._check_inputs(
            encoder_input_ids_list if encoder_input_ids else
            batch_input_ids_list, sampling_config, max_new_tokens)

        output_config = trtllm.OutputConfig(
            return_context_logits=self.gather_context_logits,
            return_generation_logits=self.gather_generation_logits,
            return_log_probs=output_log_probs,
        )

        prompt_tuning_configs = self._prepare_ptuning_executor(
            batch_input_ids_list, prompt_table, prompt_tasks)

        stop_words_list = self._prepare_words_list(stop_words_list,
                                                   len(batch_input_ids_list))
        bad_words_list = self._prepare_words_list(bad_words_list,
                                                  len(batch_input_ids_list))

        requests = [
            trtllm.Request(
                input_token_ids=input_ids,
                encoder_input_token_ids=encoder_input_ids_list[i]
                if encoder_input_ids is not None else None,
                max_new_tokens=max_new_tokens,
                pad_id=pad_id,
                end_id=end_id,
                stop_words=stop_words,
                bad_words=bad_words,
                sampling_config=sampling_config,
                streaming=streaming,
                output_config=output_config,
                prompt_tuning_config=prompt_tuning_config,
                return_all_generated_tokens=return_all_generated_tokens)
            for i, (input_ids, stop_words, bad_words,
                    prompt_tuning_config) in enumerate(
                        zip(batch_input_ids_list, stop_words_list,
                            bad_words_list, prompt_tuning_configs))
        ]

        request_ids = self.session.enqueue_requests(requests)

        if not streaming:
            return self._initialize_and_fill_output(
                request_ids, end_id, return_dict, output_sequence_lengths,
                output_log_probs, output_cum_log_probs, batch_input_ids,
                streaming, return_all_generated_tokens)
        else:
            return self._stream(request_ids, end_id, return_dict,
                                output_sequence_lengths, output_log_probs,
                                output_cum_log_probs, batch_input_ids,
                                streaming, batch_input_ids_list,
                                return_all_generated_tokens)

    def _prepare_words_list(self, words_list: List[List[List[int]]],
                            batch_size: int):
        if words_list is None:
            return [None] * batch_size
        return words_list

    def _prepare_ptuning_executor(self, batch_input_ids_list, prompt_table,
                                  prompt_tasks):
        prompt_tuning_configs = len(batch_input_ids_list) * [None]
        if prompt_table is not None:
            prompt_table_data = self._prepare_embedding_table(
                prompt_table).cuda()
            if prompt_tasks is not None:
                task_indices = [int(t) for t in prompt_tasks.split(',')]
                assert len(task_indices) == len(batch_input_ids_list), \
                    f"Number of supplied tasks ({len(task_indices)}) must match input batch size ({len(batch_input_ids_list)})"
                prompt_tuning_configs = [
                    trtllm.PromptTuningConfig(
                        embedding_table=prompt_table_data[task_indices[i]])
                    for i in range(len(batch_input_ids_list))
                ]
            else:
                prompt_tuning_configs = [
                    trtllm.PromptTuningConfig(
                        embedding_table=prompt_table_data[0])
                    for _ in range(len(batch_input_ids_list))
                ]
        return prompt_tuning_configs

    def _initialize_and_fill_output(self, request_ids, end_id, return_dict,
                                    output_sequence_lengths, output_log_probs,
                                    output_cum_log_probs, batch_input_ids,
                                    streaming, return_all_generated_tokens):
        output_ids = [[] for _ in range(len(request_ids))]
        for reqid_pos in range(len(request_ids)):
            output_ids[reqid_pos] = [[] for _ in range(self.max_beam_width)]

        multi_responses = self.session.await_responses(request_ids)
        responses = [
            response for responses in multi_responses for response in responses
        ]

        return self._fill_output(responses, output_ids, end_id, return_dict,
                                 output_sequence_lengths, output_log_probs,
                                 output_cum_log_probs, batch_input_ids,
                                 streaming, request_ids,
                                 return_all_generated_tokens)

    def _stream(self, request_ids, end_id, return_dict, output_sequence_lengths,
                output_log_probs, output_cum_log_probs, batch_input_ids,
                streaming, batch_input_ids_list, return_all_generated_tokens):
        output_ids = [[] for _ in range(len(request_ids))]
        for reqid_pos in range(len(request_ids)):
            output_ids[reqid_pos] = [
                copy.deepcopy(batch_input_ids_list[reqid_pos])
                for _ in range(self.max_beam_width)
            ]

        finished_reqs = 0
        while finished_reqs < len(request_ids):
            responses = self.session.await_responses()

            for response in responses:
                if response.result.is_final:
                    finished_reqs += 1

            yield self._fill_output(responses, output_ids, end_id, return_dict,
                                    output_sequence_lengths, output_log_probs,
                                    output_cum_log_probs, batch_input_ids,
                                    streaming, request_ids,
                                    return_all_generated_tokens)

    def _fill_output(self, responses, output_ids, end_id, return_dict,
                     output_sequence_lengths, output_log_probs,
                     output_cum_log_probs, batch_input_ids, streaming,
                     request_ids, return_all_generated_tokens):
        cuda_device = torch.device("cuda")

        for response in responses:
            if response.has_error():
                raise RuntimeError(response.error_msg)

            reqid_pos = request_ids.index(response.request_id)
            for beam, output_tokens in enumerate(
                    response.result.output_token_ids):
                if return_all_generated_tokens:
                    output_ids[reqid_pos][beam] = output_tokens
                else:
                    output_ids[reqid_pos][beam] += output_tokens

        sequence_lengths = []
        for output in output_ids:
            sequence_lengths.append([len(a) for a in output])

        if streaming:
            # In streaming mode, pad a copy so the accumulated outputs are not modified.
            output_ids = copy.deepcopy(output_ids)

        # Pad every beam with end_id up to max_seq_len so the outputs can be
        # stacked into a single rectangular tensor.
        for beam in output_ids:
            for output_tokens in beam:
                output_tokens += (self.max_seq_len -
                                  len(output_tokens)) * [end_id]

        output_ids = torch.tensor(output_ids,
                                  dtype=torch.int32,
                                  device=cuda_device)

        if return_dict:
            outputs = {'output_ids': output_ids}
            if output_sequence_lengths:
                outputs['sequence_lengths'] = torch.tensor(sequence_lengths,
                                                           dtype=torch.int32,
                                                           device=cuda_device)
            if self.gather_context_logits:
                outputs['context_logits'] = [
                    a.result.context_logits.cuda() for a in responses
                    if a.result.context_logits is not None
                ]

                # Pad each request's context logits to the longest input length
                # before stacking.
                max_input_length = max(a.shape[0]
                                       for a in outputs['context_logits'])
                for i, a in enumerate(outputs['context_logits']):
                    pad_length = max_input_length - a.shape[0]
                    outputs['context_logits'][i] = torch.nn.functional.pad(
                        a, [0, 0, 0, pad_length])
                outputs['context_logits'] = torch.stack(
                    outputs['context_logits'])
            if self.gather_generation_logits:
                outputs['generation_logits'] = [
                    a.result.generation_logits.cuda() for a in responses
                    if a.result.generation_logits is not None
                ]
                outputs['generation_logits'] = torch.stack(
                    outputs['generation_logits'])
            if output_log_probs:
                outputs['log_probs'] = [
                    a.result.log_probs for a in responses
                    if a.result.log_probs is not None
                ]

                # Pad the per-beam log probs with zeros to the longest sequence.
                max_seq_len = max(
                    len(a) for beam_list in outputs['log_probs']
                    for a in beam_list)
                for i, a in enumerate(outputs['log_probs']):
                    for j, b in enumerate(a):
                        pad_length = max_seq_len - len(b)
                        outputs['log_probs'][i][j] = b + [0.0] * pad_length
                outputs['log_probs'] = torch.tensor(outputs['log_probs'],
                                                    device=cuda_device)
            if output_cum_log_probs:
                outputs['cum_log_probs'] = [
                    a.result.cum_log_probs for a in responses
                    if a.result.cum_log_probs is not None
                ]
                outputs['cum_log_probs'] = torch.tensor(
                    outputs['cum_log_probs'], device=cuda_device)
            input_lengths = torch.tensor([x.size(0) for x in batch_input_ids],
                                         dtype=torch.int32,
                                         device=cuda_device)
            outputs = self._prepare_outputs(outputs, input_lengths)
        else:
            outputs = output_ids
        return outputs