WhisperFusion / main.py
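"""Entry point for the WhisperFusion pipeline.

Launches three cooperating processes: a Whisper TranscriptionServer for live
speech-to-text, a TensorRT-LLM engine (Phi by default; the Mistral wiring is
commented out below) for generating responses, and a WhisperSpeech TTS service
for speaking them. The processes communicate through multiprocessing queues.
"""
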
import multiprocessing
import argparse
import threading
import ssl
import time
import sys
import functools
from multiprocessing import Process, Manager, Value, Queue

from whisper_live.trt_server import TranscriptionServer
from llm_service import TensorRTLLMEngine
from tts_service import WhisperSpeechTTS


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--whisper_tensorrt_path',
                        type=str,
                        default="/root/TensorRT-LLM/examples/whisper/whisper_small_en",
                        help='Whisper TensorRT model path')
    parser.add_argument('--mistral',
                        action="store_true",
                        help='Use the Mistral LLM')
    parser.add_argument('--mistral_tensorrt_path',
                        type=str,
                        default=None,
                        help='Mistral TensorRT model path')
    parser.add_argument('--mistral_tokenizer_path',
                        type=str,
                        default="teknium/OpenHermes-2.5-Mistral-7B",
                        help='Mistral tokenizer path')
    parser.add_argument('--phi',
                        action="store_true",
                        help='Use the Phi LLM')
    parser.add_argument('--phi_tensorrt_path',
                        type=str,
                        default="/root/TensorRT-LLM/examples/phi/phi_engine",
                        help='Phi TensorRT model path')
    parser.add_argument('--phi_tokenizer_path',
                        type=str,
                        default="/root/TensorRT-LLM/examples/phi/phi-2",
                        help='Phi tokenizer path')
    return parser.parse_args()
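

# Example invocation (the default engine/tokenizer paths above assume the engines
# were built under /root/TensorRT-LLM and may need adjusting):
#   python main.py --phi
# The --mistral arguments are parsed and validated, but the LLM process below is
# currently wired to the Phi engine and tokenizer paths.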

if __name__ == "__main__":
    args = parse_arguments()

    if not args.whisper_tensorrt_path:
        raise ValueError("Please provide whisper_tensorrt_path to run the pipeline.")

    if args.mistral:
        if not args.mistral_tensorrt_path or not args.mistral_tokenizer_path:
            raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")

    if args.phi:
        if not args.phi_tensorrt_path or not args.phi_tokenizer_path:
            raise ValueError("Please provide phi_tensorrt_path and phi_tokenizer_path to run the pipeline.")
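
    # Each stage runs in its own process; 'spawn' (rather than the default fork on
    # Linux) keeps CUDA/TensorRT initialisation isolated to each child process.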
    multiprocessing.set_start_method('spawn')

    lock = multiprocessing.Lock()

    manager = Manager()
    shared_output = manager.list()

    transcription_queue = Queue()
    llm_queue = Queue()
    audio_queue = Queue()
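
    # Stage 1: the Whisper TranscriptionServer listens on 0.0.0.0:6006 and pushes
    # transcripts onto transcription_queue; it is also handed llm_queue so LLM
    # output can be relayed back to the connected client.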
    whisper_server = TranscriptionServer()
    whisper_process = multiprocessing.Process(
        target=whisper_server.run,
        args=(
            "0.0.0.0",
            6006,
            transcription_queue,
            llm_queue,
            args.whisper_tensorrt_path
        )
    )
    whisper_process.start()
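
    # Stage 2: the TensorRT-LLM engine consumes transcripts from transcription_queue
    # and is given llm_queue and audio_queue for its output. Only the Phi engine and
    # tokenizer paths are wired in; the Mistral arguments remain commented out.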
    llm_provider = TensorRTLLMEngine()
    # llm_provider = MistralTensorRTLLMProvider()
    llm_process = multiprocessing.Process(
        target=llm_provider.run,
        args=(
            # args.mistral_tensorrt_path,
            # args.mistral_tokenizer_path,
            args.phi_tensorrt_path,
            args.phi_tokenizer_path,
            transcription_queue,
            llm_queue,
            audio_queue,
        )
    )
    llm_process.start()

    # Stage 3: the WhisperSpeech TTS service listens on 0.0.0.0:8888 and synthesises
    # audio for the text it receives from audio_queue.
    tts_runner = WhisperSpeechTTS()
    tts_process = multiprocessing.Process(target=tts_runner.run, args=("0.0.0.0", 8888, audio_queue))
    tts_process.start()
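
    # Block until all three stage processes exit.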
    llm_process.join()
    whisper_process.join()
    tts_process.join()