File size: 2,377 Bytes
b03c4ad
16388cf
 
 
b03c4ad
 
 
 
 
 
 
 
 
f97a5dd
16388cf
 
 
 
 
55616b9
b03c4ad
 
 
 
 
 
 
 
16388cf
 
b03c4ad
 
4a07920
 
 
 
16388cf
b03c4ad
16388cf
b03c4ad
16388cf
ea4fabf
16388cf
55616b9
16388cf
 
 
00c7470
 
 
16388cf
 
00c7470
bd36f24
00c7470
bd36f24
 
00c7470
 
 
 
16388cf
 
 
 
 
55616b9
b03c4ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import functools
import time
import logging
logging.basicConfig(level = logging.INFO)

from websockets.sync.server import serve
from whisperspeech.pipeline import Pipeline

class WhisperSpeechTTS:
    """Serve WhisperSpeech text-to-speech audio to websocket clients.

    Responses (dicts) are pulled from a queue, the newest text is synthesised
    with a WhisperSpeech ``Pipeline``, and once a response is flagged ``eos``
    the raw audio bytes are sent to the connected websocket client.
    """

    def __init__(self):
        # The pipeline is loaded in initialize_model(); these defaults keep
        # every attribute defined before that (and before any connection).
        self.pipe = None
        self.last_llm_response = None
        self.eos = False
        self.output_audio = None

    def initialize_model(self):
        """Load the WhisperSpeech pipeline (constructed with ``torch_compile=True``)."""
        self.pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model', torch_compile=True)
        self.last_llm_response = None

    def run(self, host, port, audio_queue=None):
        """Load and warm up the model, then serve websocket clients forever.

        Args:
            host: Interface the websocket server binds to.
            port: Port the websocket server binds to.
            audio_queue: Queue of response dicts handed to every connection
                handler (see ``start_whisperspeech_tts``).
        """
        self.initialize_model()
        # A few throwaway generations so that one-time setup cost
        # (presumably torch compilation) is paid before the first real request.
        for _ in range(3):
            self.pipe.generate("Hello, I am warming up.")

        handler = functools.partial(self.start_whisperspeech_tts, audio_queue=audio_queue)
        with serve(handler, host, port) as server:
            server.serve_forever()

    def _synthesize(self, text, should_abort):
        """Generate audio for ``text`` and cache it together with ``text``.

        If ``should_abort`` raises ``TimeoutError`` mid-generation, the cached
        audio/text pair is left unchanged.
        """
        try:
            start = time.perf_counter()
            audio = self.pipe.generate(text, step_callback=should_abort)
            inference_ms = (time.perf_counter() - start) * 1000
        except TimeoutError:
            # Superseded by a newer queued response; it will be handled next.
            return
        logging.info(f"[WhisperSpeech INFO:] TTS inference done in {inference_ms:.1f} ms.\n\n")
        self.output_audio = audio.cpu().numpy()
        self.last_llm_response = text

    def start_whisperspeech_tts(self, websocket, audio_queue=None):
        """Handle one websocket connection: synthesise queued text and send audio.

        Each queue item is expected to be a dict with ``"llm_output"`` (a
        sequence whose first element is the text) and ``"eos"`` (truthy when
        the response is final). Only the newest queued item is processed.
        Audio is regenerated only when the stripped text changes. It is sent
        once an item is flagged ``eos``, and only if the cached audio was
        generated from that item's text.

        If the client is gone (``ping`` fails), the item just taken is put back
        on the queue for the next connection and the handler returns.

        Args:
            websocket: Connection object supplied by ``websockets.sync.server``.
            audio_queue: Queue of response dicts; required despite the ``None``
                default.

        Raises:
            ValueError: if ``audio_queue`` is None.
        """
        if audio_queue is None:
            raise ValueError("start_whisperspeech_tts requires an audio_queue")

        self.eos = False
        # Audio and the text it came from are a pair: reset both per
        # connection so cached text never implies audio that is not present.
        self.output_audio = None
        self.last_llm_response = None

        def should_abort():
            # Step callback for Pipeline.generate: a newer queued response
            # makes the current synthesis obsolete.
            if not audio_queue.empty():
                raise TimeoutError()

        while True:
            llm_response = audio_queue.get()
            # Skip ahead to the newest response; older ones are obsolete.
            if not audio_queue.empty():
                continue

            # Detect a closed client before doing any expensive work.
            try:
                websocket.ping()
            except Exception:
                # Hand the response back so the next connection can serve it.
                audio_queue.put(llm_response)
                break

            text = llm_response["llm_output"][0].strip()
            self.eos = llm_response["eos"]

            # Only synthesise when the text changed since the last success.
            if text != self.last_llm_response:
                self._synthesize(text, should_abort)

            # After an aborted synthesis output_audio still belongs to the
            # previous text, so require that the cached audio matches `text`.
            if self.eos and self.output_audio is not None and self.last_llm_response == text:
                try:
                    websocket.send(self.output_audio.tobytes())
                except Exception as e:
                    logging.error(f"[WhisperSpeech ERROR:] Audio error: {e}")