makaveli10 committed on
Commit
5d7959c
1 Parent(s): 795e5f6

add support for phi

Browse files
Files changed (1) hide show
  1. main.py +30 -7
main.py CHANGED
@@ -16,8 +16,11 @@ def parse_arguments():
16
  parser = argparse.ArgumentParser()
17
  parser.add_argument('--whisper_tensorrt_path',
18
  type=str,
19
- default=None,
20
  help='Whisper TensorRT model path')
 
 
 
21
  parser.add_argument('--mistral_tensorrt_path',
22
  type=str,
23
  default=None,
@@ -26,6 +29,17 @@ def parse_arguments():
26
  type=str,
27
  default="teknium/OpenHermes-2.5-Mistral-7B",
28
  help='Mistral TensorRT model path')
 
 
 
 
 
 
 
 
 
 
 
29
  return parser.parse_args()
30
 
31
 
@@ -36,10 +50,17 @@ if __name__ == "__main__":
36
  import sys
37
  sys.exit(0)
38
 
39
- if not args.mistral_tensorrt_path or not args.mistral_tokenizer_path:
40
- raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")
41
- import sys
42
- sys.exit(0)
 
 
 
 
 
 
 
43
 
44
  multiprocessing.set_start_method('spawn')
45
 
@@ -70,8 +91,10 @@ if __name__ == "__main__":
70
  llm_process = multiprocessing.Process(
71
  target=llm_provider.run,
72
  args=(
73
- args.mistral_tensorrt_path,
74
- args.mistral_tokenizer_path,
 
 
75
  transcription_queue,
76
  llm_queue,
77
  )
 
16
  parser = argparse.ArgumentParser()
17
  parser.add_argument('--whisper_tensorrt_path',
18
  type=str,
19
+ default="/root/TensorRT-LLM/examples/whisper/whisper_small_en",
20
  help='Whisper TensorRT model path')
21
+ parser.add_argument('--mistral',
22
+ action="store_true",
23
+ help='Mistral')
24
  parser.add_argument('--mistral_tensorrt_path',
25
  type=str,
26
  default=None,
 
29
  type=str,
30
  default="teknium/OpenHermes-2.5-Mistral-7B",
31
  help='Mistral TensorRT model path')
32
+ parser.add_argument('--phi',
33
+ action="store_true",
34
+ help='Phi')
35
+ parser.add_argument('--phi_tensorrt_path',
36
+ type=str,
37
+ default="/root/TensorRT-LLM/examples/phi/phi_engine",
38
+ help='Phi TensorRT model path')
39
+ parser.add_argument('--phi_tokenizer_path',
40
+ type=str,
41
+ default="/root/TensorRT-LLM/examples/phi/phi-2",
42
+ help='Phi Tokenizer path')
43
  return parser.parse_args()
44
 
45
 
 
50
  import sys
51
  sys.exit(0)
52
 
53
+ if args.mistral:
54
+ if not args.mistral_tensorrt_path or not args.mistral_tokenizer_path:
55
+ raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")
56
+ import sys
57
+ sys.exit(0)
58
+
59
+ if args.phi:
60
+ if not args.phi_tensorrt_path or not args.phi_tokenizer_path:
61
+ raise ValueError("Please provide phi_tensorrt_path and phi_tokenizer_path to run the pipeline.")
62
+ import sys
63
+ sys.exit(0)
64
 
65
  multiprocessing.set_start_method('spawn')
66
 
 
91
  llm_process = multiprocessing.Process(
92
  target=llm_provider.run,
93
  args=(
94
+ # args.mistral_tensorrt_path,
95
+ # args.mistral_tokenizer_path,
96
+ args.phi_tensorrt_path,
97
+ args.phi_tokenizer_path,
98
  transcription_queue,
99
  llm_queue,
100
  )