makaveli10 committed on
Commit 28cd485
1 Parent(s): f97a5dd

add: dolphin model inference; chatml prompt format

Files changed (1)
  1. llm_service.py +35 -10
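
The commit message points at the new ChatML prompt format. As a minimal standalone sketch (not part of the diff), this is roughly what the added format_prompt_chatml builds from a conversation history; the system prompt matches the one used in the diff, while the history and user prompts below are made-up examples.

# Sketch of the ChatML formatting introduced in this commit; contents are illustrative.
def format_prompt_chatml(prompt, conversation_history, system_prompt=""):
    formatted_prompt = "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"
    for user_prompt, llm_response in conversation_history:
        formatted_prompt += f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
        formatted_prompt += f"<|im_start|>assistant\n{llm_response}<|im_end|>\n"
    formatted_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n"
    return formatted_prompt

history = [("What's your name?", "I'm Dolphin, an AI assistant.")]
print(format_prompt_chatml("What can you help with?", history,
                           system_prompt="You are Dolphin, a helpful AI assistant"))
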
llm_service.py CHANGED

@@ -7,6 +7,7 @@ logging.basicConfig(level = logging.INFO)
 import numpy as np
 import torch
 from transformers import AutoTokenizer
+import re

 import tensorrt_llm
 from tensorrt_llm.logger import logger
@@ -177,6 +178,14 @@ class TensorRTLLMEngine:
             formatted_prompt += f"Alice: {user_prompt}\nBob:{llm_response}\n"
         return f"{formatted_prompt}Alice: {prompt}\nBob:"

+    def format_prompt_chatml(self, prompt, conversation_history, system_prompt=""):
+        formatted_prompt = ("<|im_start|>system\n" + system_prompt + "<|im_end|>\n")
+        for user_prompt, llm_response in conversation_history:
+            formatted_prompt += f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
+            formatted_prompt += f"<|im_start|>assistant\n{llm_response}<|im_end|>\n"
+        formatted_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n"
+        return formatted_prompt
+
     def run(
         self,
         model_path,
@@ -185,7 +194,7 @@
         llm_queue=None,
         audio_queue=None,
         input_text=None,
-        max_output_len=40,
+        max_output_len=50,
         max_attention_window_size=4096,
         num_beams=1,
         streaming=False,
@@ -226,12 +235,13 @@
                 print(f"History: {conversation_history}")
                 continue

-            input_text=[self.format_prompt_qa(prompt, conversation_history[transcription_output["uid"]])]
-            # print(f"Formatted prompt with history...:\n{input_text}")
+            # input_text=[self.format_prompt_qa(prompt, conversation_history[transcription_output["uid"]])]
+            input_text=[self.format_prompt_chatml(prompt, conversation_history[transcription_output["uid"]], system_prompt="You are Dolphin, a helpful AI assistant")]
+            logging.info(f"[LLM INFO:] Formatted prompt with history...:\n{input_text}")

             self.eos = transcription_output["eos"]

-            logging.info(f"[LLM INFO:] WhisperLive prompt: {prompt}, eos: {self.eos}")
+            # logging.info(f"[LLM INFO:] WhisperLive prompt: {prompt}, eos: {self.eos}")
             batch_input_ids = self.parse_input(
                 input_text=input_text,
                 add_special_tokens=True,
@@ -240,6 +250,7 @@
             )

             input_lengths = [x.size(0) for x in batch_input_ids]
+            print(f"[LLM INFO:] Input lengths: {input_lengths} / 1024")
             with torch.no_grad():
                 outputs = self.runner.generate(
                     batch_input_ids,
@@ -276,9 +287,6 @@
                     if output is None:
                         break

-                    if output is not None:
-                        if "Instruct" in output[0]:
-                            output[0] = output[0].split("Instruct")[0]
                     # Interrupted by transcription queue
                     if output is None:
                         logging.info(f"[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
@@ -298,12 +306,10 @@
                         sequence_lengths,
                         transcription_queue
                     )
-                    if output is not None:
-                        if "Instruct" in output[0]:
-                            output[0] = output[0].split("Instruct")[0]

                     # if self.eos:
                     if output is not None:
+                        output[0] = clean_llm_output(output[0])
                         self.last_output = output
                         self.last_prompt = prompt
                         llm_queue.put({"uid": transcription_output["uid"], "llm_output": output, "eos": self.eos})
@@ -316,7 +322,26 @@
                 self.last_prompt = None
                 self.last_output = None

+def clean_llm_output(output):
+    output = output.replace("\n\nDolphin\n\n", "")
+    output = output.replace("\nDolphin\n\n", "")
+
+    if not output.endswith('.') and not output.endswith('?') and not output.endswith('!'):
+        last_punct = output.rfind('.')
+        last_q = output.rfind('?')
+        if last_q > last_punct:
+            last_punct = last_q
+
+        last_ex = output.rfind('!')
+        if last_ex > last_punct:
+            last_punct = last_ex
+
+        if last_punct > 0:
+            output = output[:last_punct+1]

+    return output
+
+
 if __name__=="__main__":
     llm = TensorRTLLMEngine()
     llm.initialize_model(
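
For reference, the clean_llm_output helper added above drops the bare "Dolphin" header the model sometimes prepends to its reply and, when a generation is cut off mid-sentence by max_output_len, trims it back to the last '.', '?' or '!'. A quick illustration, assuming llm_service.py and its dependencies (torch, tensorrt_llm, transformers) are importable; the sample string is made up.

from llm_service import clean_llm_output

raw = "Sure! I can transcribe audio and answer questions. For exa"
print(clean_llm_output(raw))
# Prints: "Sure! I can transcribe audio and answer questions."
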