import datetime
import json
import os
import time
from threading import Lock, Thread

import numpy as np
import triton_python_backend_utils as pb_utils
from torch import from_numpy

import tensorrt_llm.bindings.executor as trtllm


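# The helper functions below translate Triton request tensors and model-config
# parameters into tensorrt_llm.bindings.executor objects, and translate
# executor responses back into Triton output tensors.

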
def get_input_tensor_by_name(request, name):
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is None:
        return None
    return tensor.as_numpy()


def get_input_scalar_by_name(request, name):
    tensor = get_input_tensor_by_name(request, name)
    if tensor is None:
        return None
    if tensor.size != 1:
        raise pb_utils.TritonModelException(
            f"Expected a single value for {name}")
    return tensor.item()


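# Illustrative example (hypothetical values): a "request_output_len" input of
# shape [1, 1] holding 128 yields the Python int 128; an absent optional input
# yields None, and a multi-element tensor raises a TritonModelException.

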
def read_parameter_as_type(value, name, pytype=str):
    if value == "":
        return None
    if value.startswith("${") and value.endswith("}"):
        return None
    if pytype is bool:
        return value.lower() in ["1", "true"]
    try:
        result = pytype(value)
        return result
    except Exception:
        pb_utils.Logger.log_warning(
            f"Could not read parameter '{name}' with value '{value}', will use default."
        )
        return None


def get_parameter(model_config, name, pytype=str):
    if name not in model_config['parameters']:
        return None
    return read_parameter_as_type(
        model_config['parameters'][name]['string_value'], name, pytype)


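# Illustrative config.pbtxt sketch (hypothetical values): an entry such as
#
#   parameters: { key: "max_beam_width" value: { string_value: "2" } }
#
# is read as get_parameter(model_config, "max_beam_width", int) == 2, while an
# empty value or an unexpanded template like "${max_beam_width}" returns None.

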
def convert_word_list(word_list):
    if word_list is None:
        return None
    word_list = word_list.tolist()
    if len(word_list) == 0 or len(word_list[0]) != 2:
        raise pb_utils.TritonModelException("Invalid format for word list.")
    words, indices = word_list[0]
    result = []
    current_index = 0
    for i in indices:
        if i == -1:
            continue
        if i > len(words):
            raise pb_utils.TritonModelException(
                "Invalid format for word list.")
        current_word = []
        while current_index < i:
            current_word.append(words[current_index])
            current_index += 1
        result.append(current_word)
    return result


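# Illustrative example (hypothetical token ids) for the flattened Triton
# word-list layout of shape [1, 2, N]:
#
#   convert_word_list(np.array([[[5, 6, 7], [2, 3, -1]]], dtype=np.int32))
#
# returns [[5, 6], [7]]: one word made of token ids 5 and 6, a second word made
# of token id 7; -1 entries in the offsets row are padding.

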
def parse_medusa_choices(medusa_choices):
    if medusa_choices is None:
        return None
    try:
        result = json.loads(
            "[" + medusa_choices.replace("{", "[").replace("}", "]") + "]")
        assert isinstance(result, list) and len(result) > 0
        assert all([isinstance(x, list) for x in result])
        assert all([isinstance(y, int) for x in result for y in x])
    except Exception:
        raise pb_utils.TritonModelException(
            "Invalid format for medusa_choices")
    return result


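# Illustrative example (hypothetical value): the string "{0, 0, 0}, {0, 1}"
# parses to [[0, 0, 0], [0, 1]]; anything that is not a non-empty list of
# integer lists raises a TritonModelException.

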
def get_sampling_config_from_request(request):
    kwargs = {}
    kwargs['beam_width'] = get_input_scalar_by_name(request, 'beam_width') or 1
    kwargs['top_k'] = get_input_scalar_by_name(request, 'runtime_top_k')
    top_p = get_input_scalar_by_name(request, 'runtime_top_p')
    kwargs['top_p'] = None if top_p is None or top_p <= 0 else top_p
    kwargs['random_seed'] = get_input_scalar_by_name(request, 'random_seed')
    kwargs['temperature'] = get_input_scalar_by_name(request, 'temperature')
    kwargs['min_length'] = get_input_scalar_by_name(request, 'min_length')
    kwargs['repetition_penalty'] = get_input_scalar_by_name(
        request, 'repetition_penalty')
    kwargs['presence_penalty'] = get_input_scalar_by_name(
        request, 'presence_penalty')
    kwargs['frequency_penalty'] = get_input_scalar_by_name(
        request, 'frequency_penalty')
    kwargs['length_penalty'] = get_input_scalar_by_name(request, 'len_penalty')
    kwargs['top_p_min'] = get_input_scalar_by_name(request,
                                                   'runtime_top_p_min')
    kwargs['top_p_reset_ids'] = get_input_scalar_by_name(
        request, 'runtime_top_p_reset_ids')
    kwargs['top_p_decay'] = get_input_scalar_by_name(request,
                                                     'runtime_top_p_decay')
    kwargs['beam_search_diversity_rate'] = get_input_scalar_by_name(
        request, 'beam_search_diversity_rate')
    kwargs['early_stopping'] = get_input_scalar_by_name(
        request, 'early_stopping')
    # Drop unset values so the executor falls back to its own defaults.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.SamplingConfig(**kwargs)


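# Illustrative example (hypothetical request): a client that only supplies
# runtime_top_k=1 and temperature=0.7 produces
# trtllm.SamplingConfig(beam_width=1, top_k=1, temperature=0.7), leaving every
# other sampling field at the executor's default.

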
def get_output_config_from_request(request, exclude_input_from_output):
    kwargs = {}
    kwargs["return_log_probs"] = get_input_scalar_by_name(
        request, 'return_log_probs')
    kwargs["return_context_logits"] = get_input_scalar_by_name(
        request, 'return_context_logits')
    kwargs["return_generation_logits"] = get_input_scalar_by_name(
        request, 'return_generation_logits')
    kwargs["exclude_input_from_output"] = exclude_input_from_output
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return trtllm.OutputConfig(**kwargs)


def get_external_draft_tokens_config_from_request(request):
    kwargs = {}
    draft_input_ids = get_input_tensor_by_name(request, 'draft_input_ids')
    if draft_input_ids is not None:
        kwargs['tokens'] = draft_input_ids.tolist()
    draft_logits = get_input_tensor_by_name(request, 'draft_logits')
    if draft_logits is not None:
        kwargs['logits'] = from_numpy(draft_logits)
    kwargs['acceptance_threshold'] = get_input_scalar_by_name(
        request, 'draft_acceptance_threshold')
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.ExternalDraftTokensConfig(**kwargs)
    return None


def get_prompt_tuning_config_from_request(request):
    kwargs = {}
    prompt_embedding_table = get_input_tensor_by_name(
        request, 'prompt_embedding_table')
    if prompt_embedding_table is not None:
        kwargs["embedding_table"] = from_numpy(prompt_embedding_table)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.PromptTuningConfig(**kwargs)
    return None


def get_lora_config_from_request(request):
    kwargs = {}
    kwargs["task_id"] = get_input_scalar_by_name(request, 'lora_task_id')
    lora_weights = get_input_tensor_by_name(request, 'lora_weights')
    if lora_weights is not None:
        kwargs["weights"] = from_numpy(lora_weights)
    lora_config = get_input_tensor_by_name(request, 'lora_config')
    if lora_config is not None:
        kwargs["config"] = from_numpy(lora_config)
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    if len(kwargs) > 0:
        return trtllm.LoraConfig(**kwargs)
    return None


def convert_request(request, exclude_input_from_output, decoupled):
    inputs = {}
    input_token_ids = get_input_tensor_by_name(request, 'input_ids')
    if input_token_ids is None:
        raise pb_utils.TritonModelException(
            "A value is required for input_ids")
    input_token_ids = input_token_ids.tolist()
    if len(input_token_ids) == 0:
        raise pb_utils.TritonModelException("Invalid format for input_ids")
    inputs['input_token_ids'] = input_token_ids[0]

    inputs['max_new_tokens'] = get_input_scalar_by_name(
        request, 'request_output_len')
    if inputs['max_new_tokens'] is None:
        raise pb_utils.TritonModelException(
            "A value is required for request_output_len")
    inputs['streaming'] = get_input_scalar_by_name(request, 'streaming')
    if inputs['streaming'] and not decoupled:
        raise pb_utils.TritonModelException(
            "Streaming is only supported in decoupled mode.")
    inputs['end_id'] = get_input_scalar_by_name(request, 'end_id')
    inputs['pad_id'] = get_input_scalar_by_name(request, 'pad_id')
    inputs['stop_words'] = convert_word_list(
        get_input_tensor_by_name(request, 'stop_words_list'))
    inputs['bad_words'] = convert_word_list(
        get_input_tensor_by_name(request, 'bad_words_list'))
    embedding_bias = get_input_tensor_by_name(request, 'embedding_bias')
    if embedding_bias is not None and embedding_bias.size != 0:
        inputs['embedding_bias'] = from_numpy(embedding_bias).squeeze()

    sampling_config = get_sampling_config_from_request(request)
    output_config = get_output_config_from_request(request,
                                                   exclude_input_from_output)
    external_draft_tokens_config = get_external_draft_tokens_config_from_request(
        request)
    prompt_tuning_config = get_prompt_tuning_config_from_request(request)
    lora_config = get_lora_config_from_request(request)

    return trtllm.Request(
        **inputs,
        sampling_config=sampling_config,
        output_config=output_config,
        external_draft_tokens_config=external_draft_tokens_config,
        prompt_tuning_config=prompt_tuning_config,
        lora_config=lora_config,
    )


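# Minimal request sketch (hypothetical shapes): a request only needs an
# "input_ids" tensor of shape [1, num_input_tokens] and a scalar
# "request_output_len"; every other tensor handled above is optional and is
# simply left out of the executor Request when absent.

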
def convert_response(response):
    if response.has_error():
        return pb_utils.InferenceResponse(output_tensors=[],
                                          error=pb_utils.TritonError(
                                              response.error_msg)), True
    result = response.result
    beam_lengths = np.expand_dims(
        np.array([len(beam) for beam in result.output_token_ids], np.int32), 0)
    max_beam_length = max([len(beam) for beam in result.output_token_ids])
    output_ids = np.full((1, len(result.output_token_ids), max_beam_length),
                         -1, np.int32)
    for idx, beam in enumerate(result.output_token_ids):
        output_ids[0, idx, :len(beam)] = beam
    output_tensors = [
        pb_utils.Tensor("output_ids", output_ids),
        pb_utils.Tensor("sequence_length", beam_lengths),
    ]
    output_tensors.append(
        pb_utils.Tensor(
            "cum_log_probs",
            np.expand_dims(np.array(result.cum_log_probs, np.float32), 0)
            if result.cum_log_probs is not None else np.zeros(
                (1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "output_log_probs",
            np.expand_dims(np.array(result.log_probs, np.float32), 0) if
            result.log_probs is not None else np.zeros((1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "context_logits",
            np.expand_dims(np.array(result.context_logits, np.float32), 0)
            if result.context_logits is not None else np.zeros(
                (1, 1, 1), np.float32)))
    output_tensors.append(
        pb_utils.Tensor(
            "generation_logits",
            np.expand_dims(np.array(result.generation_logits, np.float32), 0)
            if result.generation_logits is not None else np.zeros(
                (1, 1, 1, 1), np.float32)))
    return pb_utils.InferenceResponse(output_tensors), result.is_final


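# Output-shape sketch (derived from the code above): "output_ids" is
# [1, num_beams, max_beam_length] padded with -1, "sequence_length" is
# [1, num_beams], and the log-prob / logits tensors fall back to small
# zero-filled placeholders when the executor did not return them.

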
def convert_scheduler_policy(batch_scheduler_policy: str):
    if batch_scheduler_policy.lower() == "max_utilization":
        return trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    elif batch_scheduler_policy.lower() == "guaranteed_no_evict":
        return trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    raise pb_utils.TritonModelException(
        f"batch_scheduler_policy value of '{batch_scheduler_policy}' is not supported."
    )


def convert_batching_type(gpt_model_type: str):
    if gpt_model_type is None:
        return None
    if gpt_model_type.lower() in ("inflight_fused_batching",
                                  "inflight_batching"):
        return trtllm.BatchingType.INFLIGHT
    elif gpt_model_type.lower() == "v1":
        return trtllm.BatchingType.STATIC
    raise pb_utils.TritonModelException(
        f"gpt_model_type value of '{gpt_model_type}' is not supported.")


def convert_decoding_mode(decoding_mode: str):
    if decoding_mode is None:
        return None
    elif decoding_mode == "auto":
        return trtllm.DecodingMode.Auto()
    elif decoding_mode == "top_k":
        return trtllm.DecodingMode.TopK()
    elif decoding_mode == "top_p":
        return trtllm.DecodingMode.TopP()
    elif decoding_mode == "top_k_top_p":
        return trtllm.DecodingMode.TopKTopP()
    elif decoding_mode == "beam_search":
        return trtllm.DecodingMode.BeamSearch()
    elif decoding_mode == "medusa":
        return trtllm.DecodingMode.Medusa()
    raise pb_utils.TritonModelException(
        f"decoding_mode value of '{decoding_mode}' is not supported.")


def convert_timestamp_to_seconds(timestamp: str):
    return int(
        datetime.datetime.strptime(timestamp, "%m-%d-%Y %H:%M:%S").timestamp())


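# Illustrative example (hypothetical timestamp): "05-01-2024 13:45:00" is
# parsed as month-day-year wall-clock time and converted to POSIX seconds using
# the server's local timezone; this is the format the metrics loop below
# expects for the executor's iteration-stats timestamp.

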
class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def get_scheduler_config(self, model_config):
        batch_scheduler_policy = get_parameter(model_config,
                                               "batch_scheduler_policy")
        if batch_scheduler_policy is None:
            return trtllm.SchedulerConfig()
        return trtllm.SchedulerConfig(
            convert_scheduler_policy(batch_scheduler_policy))

    def get_kv_cache_config(self, model_config):
        kwargs = {
            "enable_block_reuse":
            get_parameter(model_config, "enable_kv_cache_reuse", bool),
            "max_tokens":
            get_parameter(model_config, "max_tokens_in_paged_kv_cache", int),
            "sink_token_length":
            get_parameter(model_config, "sink_token_length", int),
            "max_attention_window":
            get_parameter(model_config, "max_attention_window_size", int),
            "free_gpu_memory_fraction":
            get_parameter(model_config, "kv_cache_free_gpu_mem_fraction",
                          float),
            "host_cache_size":
            get_parameter(model_config, "kv_cache_host_memory_bytes", int),
            "onboard_blocks":
            get_parameter(model_config, "kv_cache_onboard_blocks", bool),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.KvCacheConfig(**kwargs)

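    # Illustrative config.pbtxt sketch (hypothetical values) for the KV-cache
    # parameters read above:
    #
    #   parameters: { key: "enable_kv_cache_reuse" value: { string_value: "true" } }
    #   parameters: { key: "kv_cache_free_gpu_mem_fraction" value: { string_value: "0.9" } }
    #
    # Unset or templated values ("${...}") are skipped so KvCacheConfig keeps
    # its defaults.
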
    def get_parallel_config(self, model_config):
        kwargs = {}
        gpu_device_ids = get_parameter(model_config, "gpu_device_ids")
        if gpu_device_ids:
            kwargs["device_ids"] = [int(x) for x in gpu_device_ids.split(",")]
        self.use_orchestrator_mode = os.environ.get("TRTLLM_ORCHESTRATOR",
                                                    "0") == "1"
        if self.use_orchestrator_mode:
            kwargs[
                "communication_mode"] = trtllm.CommunicationMode.ORCHESTRATOR
            worker_path = get_parameter(model_config, "worker_path")
            if worker_path is not None:
                raise pb_utils.TritonModelException(
                    "worker_path parameter is specified, but this is no longer supported. Please specify executor_worker_path instead to specify the location of the trtllmExecutorWorker executable."
                )
            executor_worker_path = get_parameter(model_config,
                                                 "executor_worker_path")
            kwargs["orchestrator_config"] = trtllm.OrchestratorConfig(
                True, executor_worker_path)
        if len(kwargs) > 0:
            return trtllm.ParallelConfig(**kwargs)
        return None

    def get_peft_cache_config(self, model_config):
        kwargs = {
            "optimal_adapter_size":
            get_parameter(model_config, "lora_cache_optimal_adapter_size",
                          int),
            "max_adapter_size":
            get_parameter(model_config, "lora_cache_max_adapter_size", int),
            "device_cache_percent":
            get_parameter(model_config, "lora_cache_gpu_memory_fraction",
                          float),
            "host_cache_size":
            get_parameter(model_config, "lora_cache_host_memory_bytes", int),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.PeftCacheConfig(**kwargs)

    def get_decoding_config(self, model_config):
        kwargs = {
            "medusa_choices":
            parse_medusa_choices(get_parameter(model_config,
                                               "medusa_choices")),
            "decoding_mode":
            convert_decoding_mode(get_parameter(model_config,
                                                "decoding_mode")),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.DecodingConfig(**kwargs)

    def get_executor_config(self, model_config):
        kwargs = {
            "max_beam_width":
            get_parameter(model_config, "max_beam_width", int),
            "scheduler_config":
            self.get_scheduler_config(model_config),
            "kv_cache_config":
            self.get_kv_cache_config(model_config),
            "enable_chunked_context":
            get_parameter(model_config, "enable_chunked_context", bool),
            "normalize_log_probs":
            get_parameter(model_config, "normalize_log_probs", bool),
            "batching_type":
            convert_batching_type(get_parameter(model_config,
                                                "gpt_model_type")),
            "parallel_config":
            self.get_parallel_config(model_config),
            "peft_cache_config":
            self.get_peft_cache_config(model_config),
            "decoding_config":
            self.get_decoding_config(model_config),
        }
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return trtllm.ExecutorConfig(**kwargs)

    def create_metrics(self, model: str, version: str, is_v1_model: bool):
        self.request_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_request_metrics",
            description="TRT LLM request metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.runtime_memory_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_runtime_memory_metrics",
            description="TRT LLM runtime memory metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.kv_cache_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_kv_cache_block_metrics",
            description="TRT LLM KV cache block metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        model_type = "v1" if is_v1_model else "inflight_batcher"
        self.model_type_metric_family = pb_utils.MetricFamily(
            name=f"nv_trt_llm_{model_type}_metrics",
            description=f"TRT LLM {model_type}-specific metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        self.general_metric_family = pb_utils.MetricFamily(
            name="nv_trt_llm_general_metrics",
            description="General TRT LLM metrics",
            kind=pb_utils.MetricFamily.GAUGE,
        )
        common_labels = {"model": model, "version": version}
        self.all_metrics = {
            # Request metrics
            "num_active_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "active",
                **common_labels
            }),
            "max_num_active_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "max",
                **common_labels
            }),
            "num_scheduled_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "scheduled",
                **common_labels
            }),
            "num_context_requests":
            self.request_metric_family.Metric(labels={
                "request_type": "context",
                **common_labels
            }),
            # Runtime memory metrics
            "cpu_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "cpu",
                **common_labels
            }),
            "gpu_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "gpu",
                **common_labels
            }),
            "pinned_mem_usage":
            self.runtime_memory_metric_family.Metric(labels={
                "memory_type": "pinned",
                **common_labels
            }),
            # KV cache block metrics
            "max_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "max",
                **common_labels
            }),
            "free_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "free",
                **common_labels
            }),
            "used_num_blocks":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "used",
                **common_labels
            }),
            "tokens_per_block":
            self.kv_cache_metric_family.Metric(labels={
                "kv_cache_block_type": "tokens_per",
                **common_labels
            }),
            # General metrics
            "timestamp":
            self.general_metric_family.Metric(labels={
                "general_type": "timestamp",
                **common_labels
            }),
            "iter":
            self.general_metric_family.Metric(labels={
                "general_type": "iteration_counter",
                **common_labels
            }),
        }
        if is_v1_model:
            self.all_metrics.update({
                "num_ctx_tokens":
                self.model_type_metric_family.Metric(labels={
                    "v1_specific_metric": "total_context_tokens",
                    **common_labels
                }),
                "num_gen_tokens":
                self.model_type_metric_family.Metric(
                    labels={
                        "v1_specific_metric": "total_generation_tokens",
                        **common_labels
                    }),
                "empty_gen_slots":
                self.model_type_metric_family.Metric(
                    labels={
                        "v1_specific_metric": "empty_generation_slots",
                        **common_labels
                    }),
            })
        else:
            self.all_metrics.update({
                "num_ctx_tokens":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric":
                        "total_context_tokens",
                        **common_labels
                    }),
                "num_gen_requests":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric":
                        "generation_requests",
                        **common_labels
                    }),
                "micro_batch_id":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric": "micro_batch_id",
                        **common_labels
                    }),
                "num_paused_requests":
                self.model_type_metric_family.Metric(
                    labels={
                        "inflight_batcher_specific_metric": "paused_requests",
                        **common_labels
                    }),
            })

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        model_config = json.loads(args['model_config'])
        gpt_model_path = get_parameter(model_config, "gpt_model_path")
        if get_parameter(model_config, "enable_trt_overlap", bool):
            raise pb_utils.TritonModelException(
                "enable_trt_overlap=true is not supported.")
        self.exclude_input_from_output = get_parameter(
            model_config, "exclude_input_in_output", bool)
        executor_config = self.get_executor_config(model_config)
        self.executor = trtllm.Executor(gpt_model_path,
                                        trtllm.ModelType.DECODER_ONLY,
                                        executor_config)
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
            model_config)
        self.cancellation_check_period_ms = get_parameter(
            model_config, "cancellation_check_period_ms", int) or 100
        self.stats_check_period_ms = get_parameter(
            model_config, "stats_check_period_ms", int) or 100

        if not self.decoupled:
            raise pb_utils.TritonModelException(
                "Please enable decoupled transaction policy in the model configuration to serve this model"
            )

        self.create_metrics(args["model_name"],
                            args["model_version"],
                            is_v1_model=executor_config.batching_type ==
                            trtllm.BatchingType.STATIC)
        self.triton_id_to_req_id = {}
        self.req_id_to_response_sender = {}
        self.lock = Lock()
        self.running = False
        self.awaiter_thread = Thread(target=self.awaiter_loop)
        self.cancellation_thread = Thread(target=self.cancellation_loop)
        self.metrics_thread = Thread(target=self.metrics_loop)
        if self.executor.can_enqueue_requests():
            self.running = True
            self.awaiter_thread.start()
            self.cancellation_thread.start()
            self.metrics_thread.start()
        else:
            self.executor.shutdown()

    def handle_stop_request(self, triton_id, response_sender):
        if triton_id is None or triton_id == "":
            response_sender.send(
                pb_utils.InferenceResponse(error=pb_utils.TritonError(
                    "A request id must be provided for request cancellation")),
                flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            return

        if triton_id in self.triton_id_to_req_id:
            req_id = self.triton_id_to_req_id[triton_id]
            self.executor.cancel_request(req_id)

        response_sender.send(
            pb_utils.InferenceResponse(),
            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

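    # Cancellation sketch (derived from execute/handle_stop_request below): a
    # client cancels an in-flight request by sending a new request that reuses
    # the same Triton request id and sets the boolean "stop" input to true.
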
    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        if not self.executor.can_enqueue_requests():
            return

        triton_requests = []
        executor_requests = []
        for request in requests:
            response_sender = request.get_response_sender()
            if get_input_scalar_by_name(request, 'stop'):
                self.handle_stop_request(request.request_id(), response_sender)
            else:
                try:
                    converted = convert_request(request,
                                                self.exclude_input_from_output,
                                                self.decoupled)
                except Exception as e:
                    response_sender.send(
                        pb_utils.InferenceResponse(error=pb_utils.TritonError(
                            f"An error occurred when processing the input values for request id {request.request_id()}, the error was '{e}'"
                        )),
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    triton_requests.append(request)
                    executor_requests.append(converted)

        with self.lock:
            request_ids = self.executor.enqueue_requests(executor_requests)
            for req_id, request in zip(request_ids, triton_requests):
                triton_id = request.request_id()
                self.req_id_to_response_sender[
                    req_id] = triton_id, request.get_response_sender()
                self.triton_id_to_req_id[triton_id] = req_id
        return None

    def awaiter_loop(self):
        """Gets responses from executor and returns the results."""
        while self.running:
            for response in self.executor.await_responses(
                    timeout=datetime.timedelta(milliseconds=1)):
                req_id = response.request_id
                with self.lock:
                    if req_id not in self.req_id_to_response_sender:
                        continue
                    triton_id, response_sender = self.req_id_to_response_sender[
                        req_id]

                triton_response, is_final = convert_response(response)
                response_sender.send(
                    triton_response,
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                    if is_final else 0)

                if is_final:
                    with self.lock:
                        del self.triton_id_to_req_id[triton_id]
                        del self.req_id_to_response_sender[req_id]

                # Drop the local reference so the response sender can be
                # released promptly.
                del response_sender

    def cancellation_loop(self):
        """Checks if any pending requests have been cancelled."""
        while self.running:
            time.sleep(self.cancellation_check_period_ms / 1000.0)
            with self.lock:
                for req_id, (triton_id, response_sender
                             ) in self.req_id_to_response_sender.items():
                    if response_sender.is_cancelled():
                        self.executor.cancel_request(req_id)
                    # Drop the local reference so the response sender can be
                    # released promptly.
                    del response_sender

    def metrics_loop(self):
        """Updates triton metrics using stats from the executor."""
        while self.running:
            time.sleep(self.stats_check_period_ms / 1000.0)
            for stat in self.executor.get_latest_iteration_stats():
                try:
                    for key, metric in self.all_metrics.items():
                        value = None
                        if hasattr(stat, key):
                            value = getattr(stat, key)
                        elif stat.kv_cache_stats is not None and hasattr(
                                stat.kv_cache_stats, key):
                            value = getattr(stat.kv_cache_stats, key)
                        elif stat.static_batching_stats is not None and hasattr(
                                stat.static_batching_stats, key):
                            value = getattr(stat.static_batching_stats, key)
                        elif stat.inflight_batching_stats is not None and hasattr(
                                stat.inflight_batching_stats, key):
                            value = getattr(stat.inflight_batching_stats, key)
                        if value is not None:
                            if key == "timestamp":
                                value = convert_timestamp_to_seconds(value)
                            metric.set(value)
                        else:
                            pb_utils.Logger.log_warn(
                                f"Metric \"{key}\" not found.")
                except Exception as e:
                    pb_utils.Logger.log_warn(
                        f"Error while processing metrics: {e}")

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        if self.executor.can_enqueue_requests():
            self.running = False
            self.awaiter_thread.join()
            self.cancellation_thread.join()
            self.metrics_thread.join()
            self.executor.shutdown()