File size: 5,499 Bytes
c490c32
 
 
 
 
 
 
 
90a9891
7a1cd88
c490c32
 
 
 
 
 
90a9891
c490c32
 
 
 
 
 
90a9891
c490c32
 
7a1cd88
c490c32
2bb91de
c490c32
 
 
2bb91de
 
 
 
 
 
90a9891
2bb91de
 
 
 
90a9891
2bb91de
 
 
 
 
 
 
c490c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90a9891
c490c32
 
 
7a1cd88
 
 
c490c32
 
 
 
 
814feb3
 
c490c32
 
 
 
7a1cd88
c490c32
 
 
 
2bb91de
 
c490c32
2bb91de
 
 
 
 
 
 
 
 
 
 
 
 
7a1cd88
 
 
 
2bb91de
 
4df5a8a
 
 
 
2bb91de
4df5a8a
2bb91de
4df5a8a
2bb91de
 
4df5a8a
2bb91de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from asyncio import Queue, TaskGroup
import asyncio
from contextlib import asynccontextmanager

import ray
from chat_service import ChatService
# from local_speaker_service import LocalSpeakerService
from text_to_speech_service import TextToSpeechService
from response_state_manager import ResponseStateManager
from ffmpeg_converter import FFMpegConverter
from agent_response import AgentResponse
import json

class RespondToPromptAsync:
    """Turn one prompt into streamed LLM sentences, synthesized speech and converted audio.

    ``run`` drives two concurrent coroutines:

    * ``prompt_to_llm`` streams sentences from ``ChatService`` and spawns one
      ``llm_sentence_to_speech`` task (with its own queue) per complete sentence.
    * ``speech_to_converter`` drains those per-sentence queues in sentence order
      and pushes the audio into an ``FFMpegConverter`` built on
      ``audio_output_queue``.

    ``terminate`` cancels everything and flushes the internal queues.
    """

    def __init__(
            self,
            response_state_manager: ResponseStateManager,
            audio_output_queue,
            voice_id: str = "2OviOUQc1JsQRQgNkVBj"):
        """
        Args:
            response_state_manager: Receives LLM previews, completed sentences and
                TTS chunk ids as they are produced.
            audio_output_queue: Passed to ``FFMpegConverter`` (presumably its output
                sink — confirm against FFMpegConverter).
            voice_id: Voice handed to ``TextToSpeechService``. The default is the
                voice that was previously hard-coded here.
        """
        self.llm_sentence_queue = Queue(maxsize=100)
        self.speech_chunk_queue = Queue(maxsize=100)
        self.voice_id = voice_id
        self.audio_output_queue = audio_output_queue
        self.response_state_manager = response_state_manager
        # Parallel lists: sentence_queues[i] is the output queue of sentence_tasks[i].
        self.sentence_queues = []
        self.sentence_tasks = []
        # Populated later by run() / speech_to_converter(). They are initialised
        # here so terminate() is safe to call at any point in the lifecycle.
        self.task_group_tasks = []
        self.ffmpeg_converter = None
        self.ffmpeg_converter_task = None

    async def prompt_to_llm(self, prompt: str, messages: list[str]):
        """Stream the LLM reply for ``messages`` and start TTS for every complete sentence.

        Handling of each streamed piece of text:

        * Incomplete sentences, and sentences ``ChatService.ignore_sentence``
          rejects, are only published as the LLM preview.
        * Each complete sentence is recorded on the state manager and printed.
        * A copy of the running ``AgentResponse`` is then handed to its own
          ``llm_sentence_to_speech`` task, which writes into a fresh queue.

        Returns once the stream ends and every spawned TTS task has finished
        (the enclosing ``TaskGroup`` waits for them).

        Args:
            prompt: Stored on the ``AgentResponse`` for bookkeeping/logging.
            messages: Passed unchanged to
                ``ChatService.get_responses_as_sentances_async``.
        """
        chat_service = ChatService()

        async with TaskGroup() as tg:
            agent_response = AgentResponse(prompt)
            async for text, is_complete_sentence in chat_service.get_responses_as_sentances_async(messages):
                if chat_service.ignore_sentence(text):
                    is_complete_sentence = False
                if not is_complete_sentence:
                    agent_response['llm_preview'] = text
                    self.response_state_manager.set_llm_preview(text)
                    continue
                agent_response['llm_preview'] = ''
                agent_response['llm_sentence'] = text
                agent_response['llm_sentences'].append(text)
                self.response_state_manager.add_llm_response_and_clear_llm_preview(text)
                print(f"{agent_response['llm_sentence']} id: {agent_response['llm_sentence_id']} from prompt: {agent_response['prompt']}")
                # Snapshot the response so later mutations (next sentence id, etc.)
                # do not leak into the TTS task for this sentence.
                sentence_response = agent_response.make_copy()
                new_queue = Queue()
                # Append queue and task with no await in between, so the two
                # lists stay index-aligned for speech_to_converter.
                self.sentence_queues.append(new_queue)
                task = tg.create_task(self.llm_sentence_to_speech(sentence_response, new_queue))
                self.sentence_tasks.append(task)
                agent_response['llm_sentence_id'] += 1

    async def llm_sentence_to_speech(self, sentence_response, output_queue):
        """Synthesize one sentence and forward each audio chunk to ``output_queue``.

        For every chunk, this coroutine also reports a JSON id to the state manager.
        The id is built from ``prompt``, ``llm_sentence_id`` and a per-sentence
        ``chunk_count``.

        Args:
            sentence_response: Per-sentence copy of the ``AgentResponse``.
            output_queue: This sentence's queue, read by ``speech_to_converter``.
        """
        tts_service = TextToSpeechService(self.voice_id)

        chunk_count = 0
        async for chunk_response in tts_service.get_speech_chunks_async(sentence_response):
            await output_queue.put(chunk_response.make_copy())
            chunk_id = {
                'prompt': sentence_response['prompt'],
                'llm_sentence_id': sentence_response['llm_sentence_id'],
                'chunk_count': chunk_count,
            }
            self.response_state_manager.add_tts_raw_chunk_id(
                json.dumps(chunk_id), sentence_response['llm_sentence_id'])
            chunk_count += 1

    async def speech_to_converter(self):
        """Start the FFmpeg converter and feed it audio, one sentence at a time, in order.

        Each pass walks the sentence task/queue pairs in creation order:

        * Pairs whose task is finished and whose queue is drained are skipped.
        * The first remaining pair has its queued chunks pushed to the converter.
        * The pass then stops, so later sentences wait until earlier ones are
          complete.

        Loops forever. It runs until cancelled (e.g. by ``terminate``).
        """
        self.ffmpeg_converter = FFMpegConverter(self.audio_output_queue)
        await self.ffmpeg_converter.start_process()
        self.ffmpeg_converter_task = asyncio.create_task(self.ffmpeg_converter.run())

        while True:
            for task, queue in zip(self.sentence_tasks, self.sentence_queues):
                if task.done() and queue.empty():
                    continue
                while not queue.empty():
                    chunk_response = await queue.get()
                    # Await the Ray ObjectRef instead of ray.get(): ray.get blocks
                    # the whole event loop (LLM streaming, TTS) until the object
                    # is available.
                    audio_chunk = await chunk_response['tts_raw_chunk_ref']
                    await self.ffmpeg_converter.push_chunk(audio_chunk)
                break

            await asyncio.sleep(0.01)

    async def run(self, prompt: str, messages: list[str]):
        """Run LLM streaming and audio conversion concurrently for one prompt.

        Because ``speech_to_converter`` never returns on its own, this only
        finishes when it is cancelled (e.g. via ``terminate``) or a task fails.
        """
        self.task_group_tasks = []
        async with TaskGroup() as tg:
            t1 = tg.create_task(self.prompt_to_llm(prompt, messages))
            t2 = tg.create_task(self.speech_to_converter())
            self.task_group_tasks.extend([t1, t2])

    async def terminate(self):
        """Cancel all work, close the FFmpeg converter and flush the internal queues.

        Safe to call before ``run`` has started or before the converter exists.
        """
        for task in self.task_group_tasks:
            task.cancel()
        for task in self.sentence_tasks:
            task.cancel()

        if self.ffmpeg_converter_task is not None:
            self.ffmpeg_converter_task.cancel()
            await self.ffmpeg_converter.close()

        # TODO: re-enable flushing self.audio_output_queue (via get_async) to
        # interrupt playback when the user speaks.
        for queue in (self.llm_sentence_queue, self.speech_chunk_queue, *self.sentence_queues):
            while not queue.empty():
                queue.get_nowait()