import json
import logging
import re
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from transformers import AutoTokenizer, T5Tokenizer

import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner

if PYTHON_BINDINGS:
    from tensorrt_llm.runtime import ModelRunnerCpp

logging.basicConfig(level=logging.INFO)


def read_model_name(engine_dir: str):
    engine_version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir)

    with open(Path(engine_dir) / "config.json", 'r') as f:
        config = json.load(f)

    # Older engines expose the model name under builder_config,
    # newer ones under pretrained_config.
    if engine_version is None:
        return config['builder_config']['name']

    return config['pretrained_config']['architecture']


def throttle_generator(generator, stream_interval):
    # Yield every `stream_interval`-th item while streaming.
    for i, out in enumerate(generator):
        if not i % stream_interval:
            yield out

    # Yield the final item if the loop above skipped it.
    if i % stream_interval:
        yield out


def load_tokenizer(tokenizer_dir: Optional[str] = None,
                   vocab_file: Optional[str] = None,
                   model_name: str = 'gpt',
                   tokenizer_type: Optional[str] = None):
    if vocab_file is None:
        use_fast = True
        if tokenizer_type is not None and tokenizer_type == "llama":
            use_fast = False
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir,
                                                  legacy=False,
                                                  padding_side='left',
                                                  truncation_side='left',
                                                  trust_remote_code=True,
                                                  tokenizer_type=tokenizer_type,
                                                  use_fast=use_fast)
    else:
        # Load from a SentencePiece vocab file; only supported for GPT models here.
        assert model_name == 'gpt'
        tokenizer = T5Tokenizer(vocab_file=vocab_file,
                                padding_side='left',
                                truncation_side='left')

    if model_name == 'qwen':
        with open(Path(tokenizer_dir) / "generation_config.json") as f:
            gen_config = json.load(f)
        chat_format = gen_config['chat_format']
        if chat_format == 'raw':
            pad_id = gen_config['pad_token_id']
            end_id = gen_config['eos_token_id']
        elif chat_format == 'chatml':
            pad_id = tokenizer.im_end_id
            end_id = tokenizer.im_end_id
        else:
            raise Exception(f"unknown chat format: {chat_format}")
    elif model_name == 'glm_10b':
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eop_token_id
    else:
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eos_token_id

    return tokenizer, pad_id, end_id


class TensorRTLLMEngine:
    def __init__(self):
        pass

    def initialize_model(self, engine_dir, tokenizer_dir):
        self.log_level = 'error'
        self.runtime_rank = tensorrt_llm.mpi_rank()
        logger.set_level(self.log_level)
        model_name = read_model_name(engine_dir)
        self.tokenizer, self.pad_id, self.end_id = load_tokenizer(
            tokenizer_dir=tokenizer_dir,
            vocab_file=None,
            model_name=model_name,
            tokenizer_type=None,
        )
        self.prompt_template = None
        self.runner_cls = ModelRunner
        self.runner_kwargs = dict(engine_dir=engine_dir,
                                  lora_dir=None,
                                  rank=self.runtime_rank,
                                  debug_mode=False,
                                  lora_ckpt_source='hf')
        self.runner = self.runner_cls.from_dir(**self.runner_kwargs)
        self.last_prompt = None
        self.last_output = None

    def parse_input(
        self,
        input_text=None,
        add_special_tokens=True,
        max_input_length=923,
        pad_id=None,
    ):
        if self.pad_id is None:
            self.pad_id = self.tokenizer.pad_token_id

        batch_input_ids = []
        for curr_text in input_text:
            if self.prompt_template is not None:
                curr_text = self.prompt_template.format(input_text=curr_text)
            input_ids = self.tokenizer.encode(
                curr_text,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=max_input_length
            )
            batch_input_ids.append(input_ids)

        batch_input_ids = [
            torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
        ]
        return batch_input_ids

    def decode_tokens(
        self,
        output_ids,
        input_lengths,
        sequence_lengths,
        transcription_queue
    ):
        batch_size, num_beams, _ = output_ids.size()
        for batch_idx in range(batch_size):
            # Abort early if a newer transcription is already waiting.
            if transcription_queue.qsize() != 0:
                return None

            inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist()
            input_text = self.tokenizer.decode(inputs)
            output = []
            for beam in range(num_beams):
                if transcription_queue.qsize() != 0:
                    return None

                output_begin = input_lengths[batch_idx]
                output_end = sequence_lengths[batch_idx][beam]
                outputs = output_ids[batch_idx][beam][output_begin:output_end].tolist()
                output_text = self.tokenizer.decode(outputs)
                output.append(output_text)
        return output

    def format_prompt_qa(self, prompt, conversation_history):
        formatted_prompt = ""
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"Instruct: {user_prompt}\nOutput:{llm_response}\n"
        return f"{formatted_prompt}Instruct: {prompt}\nOutput:"

    def format_prompt_chat(self, prompt, conversation_history):
        formatted_prompt = ""
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"Alice: {user_prompt}\nBob:{llm_response}\n"
        return f"{formatted_prompt}Alice: {prompt}\nBob:"

    def format_prompt_chatml(self, prompt, conversation_history, system_prompt=""):
        formatted_prompt = (
            "<|im_start|>system\n" + system_prompt + "<|im_end|>\n")
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            formatted_prompt += f"<|im_start|>assistant\n{llm_response}<|im_end|>\n"
        formatted_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n"
        return formatted_prompt

    def run(
        self,
        model_path,
        tokenizer_path,
        transcription_queue=None,
        llm_queue=None,
        audio_queue=None,
        input_text=None,
        max_output_len=50,
        max_attention_window_size=4096,
        num_beams=1,
        streaming=False,
        streaming_interval=4,
        debug=False,
    ):
        self.initialize_model(
            model_path,
            tokenizer_path,
        )

        logging.info("[LLM INFO:] Loaded LLM TensorRT Engine.")

        conversation_history = {}

        while True:
            # Block until the transcription service produces a prompt.
            transcription_output = transcription_queue.get()
            # If newer transcriptions are already queued, this one is stale.
            if transcription_queue.qsize() != 0:
                continue

            if transcription_output["uid"] not in conversation_history:
                conversation_history[transcription_output["uid"]] = []

            prompt = transcription_output['prompt'].strip()

            # If the prompt has not changed since the last run, re-send the
            # cached output instead of running inference again.
            if self.last_prompt == prompt:
                if self.last_output is not None and transcription_output["eos"]:
                    self.eos = transcription_output["eos"]
                    llm_queue.put({
                        "uid": transcription_output["uid"],
                        "llm_output": self.last_output,
                        "eos": self.eos,
                        "latency": self.infer_time
                    })
                    audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
                    conversation_history[transcription_output["uid"]].append(
                        (transcription_output['prompt'].strip(), self.last_output[0].strip())
                    )
                    continue

            input_text = [
                self.format_prompt_chatml(
                    prompt,
                    conversation_history[transcription_output["uid"]],
                    system_prompt="You are Dolphin, a helpful AI assistant"
                )
            ]

            self.eos = transcription_output["eos"]

            batch_input_ids = self.parse_input(
                input_text=input_text,
                add_special_tokens=True,
                max_input_length=923,
                pad_id=None,
            )

            input_lengths = [x.size(0) for x in batch_input_ids]

            logging.info(
                f"[LLM INFO:] Running LLM Inference with WhisperLive prompt: {prompt}, eos: {self.eos}")
            start = time.time()
            with torch.no_grad():
                outputs = self.runner.generate(
                    batch_input_ids,
                    max_new_tokens=max_output_len,
                    max_attention_window_size=max_attention_window_size,
                    end_id=self.end_id,
                    pad_id=self.pad_id,
                    temperature=1.0,
                    top_k=1,
                    top_p=0.0,
                    num_beams=num_beams,
                    length_penalty=1.0,
                    repetition_penalty=1.0,
                    stop_words_list=None,
                    bad_words_list=None,
                    lora_uids=None,
                    prompt_table_path=None,
                    prompt_tasks=None,
                    streaming=streaming,
                    output_sequence_lengths=True,
                    return_dict=True)
                torch.cuda.synchronize()

            if streaming:
                for curr_outputs in throttle_generator(outputs, streaming_interval):
                    output_ids = curr_outputs['output_ids']
                    sequence_lengths = curr_outputs['sequence_lengths']
                    output = self.decode_tokens(
                        output_ids,
                        input_lengths,
                        sequence_lengths,
                        transcription_queue
                    )

                    # decode_tokens returns None when a newer prompt arrived.
                    if output is None:
                        break

                if output is None:
                    continue
            else:
                output_ids = outputs['output_ids']
                sequence_lengths = outputs['sequence_lengths']
                context_logits = None
                generation_logits = None
                if self.runner.gather_context_logits:
                    context_logits = outputs['context_logits']
                if self.runner.gather_generation_logits:
                    generation_logits = outputs['generation_logits']
                output = self.decode_tokens(
                    output_ids,
                    input_lengths,
                    sequence_lengths,
                    transcription_queue
                )
            self.infer_time = time.time() - start

            if output is not None:
                output[0] = clean_llm_output(output[0])
                self.last_output = output
                self.last_prompt = prompt
                llm_queue.put({
                    "uid": transcription_output["uid"],
                    "llm_output": output,
                    "eos": self.eos,
                    "latency": self.infer_time
                })
                audio_queue.put({"llm_output": output, "eos": self.eos})
                logging.info(
                    f"[LLM INFO:] Output: {output[0]}\nLLM inference done in {self.infer_time:.3f} s\n\n")

                # Once the utterance is final, commit it to the conversation
                # history and reset the cached prompt/output.
                if self.eos:
                    conversation_history[transcription_output["uid"]].append(
                        (transcription_output['prompt'].strip(), output[0].strip())
                    )
                    self.last_prompt = None
                    self.last_output = None


def clean_llm_output(output):
    # Strip speaker tags that the model sometimes emits.
    output = output.replace("\n\nDolphin\n\n", "")
    output = output.replace("\nDolphin\n\n", "")
    output = output.replace("Dolphin: ", "")
    output = output.replace("Assistant: ", "")

    # If the text ends mid-sentence, trim it back to the last '.', '?' or '!'.
    if not output.endswith(('.', '?', '!')):
        last_punct = output.rfind('.')
        last_q = output.rfind('?')
        if last_q > last_punct:
            last_punct = last_q

        last_ex = output.rfind('!')
        if last_ex > last_punct:
            last_punct = last_ex

        if last_punct > 0:
            output = output[:last_punct + 1]

    return output
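

# --- Usage sketch (illustrative, not part of the original pipeline code) ---
# A minimal example of how TensorRTLLMEngine.run() might be driven: it blocks on
# `transcription_queue`, runs inference, and publishes results to `llm_queue` and
# `audio_queue`. The engine/tokenizer paths are placeholders, and the queue wiring
# is an assumption about how the surrounding application feeds this service.
if __name__ == "__main__":
    from multiprocessing import Queue

    transcription_queue = Queue()
    llm_queue = Queue()
    audio_queue = Queue()

    # run() reads the "uid", "prompt" and "eos" fields from each queue item.
    transcription_queue.put({
        "uid": "client-0",
        "prompt": "Hello, who are you?",
        "eos": True,
    })

    engine = TensorRTLLMEngine()
    # Blocks forever, serving prompts as they arrive.
    engine.run(
        model_path="path/to/trt_engine_dir",      # placeholder engine directory
        tokenizer_path="path/to/tokenizer_dir",   # placeholder tokenizer directory
        transcription_queue=transcription_queue,
        llm_queue=llm_queue,
        audio_queue=audio_queue,
    )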