from asyncio import Queue, TaskGroup
import asyncio
from contextlib import asynccontextmanager
import ray
from chat_service import ChatService
# from local_speaker_service import LocalSpeakerService
from text_to_speech_service import TextToSpeechService
from response_state_manager import ResponseStateManager
from ffmpeg_converter import FFMpegConverter
from agent_response import AgentResponse
import json


class RespondToPromptAsync:
    """Pipe one prompt through the LLM, text-to-speech and FFmpeg stages.

    ``run`` starts two concurrent coroutines:

    * ``prompt_to_llm`` reads sentences from ``ChatService`` and spawns one
      ``llm_sentence_to_speech`` task (with its own queue) per sentence.
    * ``speech_to_converter`` polls those per-sentence queues in sentence
      order and pushes the audio chunks into an ``FFMpegConverter``.

    ``terminate`` cancels everything and drains the queues.
    """

    DEFAULT_VOICE_ID = "2OviOUQc1JsQRQgNkVBj"

    def __init__(
            self,
            response_state_manager: ResponseStateManager,
            audio_output_queue,
            voice_id: str = DEFAULT_VOICE_ID):
        """Store collaborators and create empty pipeline state.

        Args:
            response_state_manager: Receives LLM previews, completed
                sentences and TTS chunk ids as they are produced.
            audio_output_queue: Handed unchanged to ``FFMpegConverter``.
            voice_id: Voice identifier handed to ``TextToSpeechService``.
        """
        self.voice_id = voice_id
        self.response_state_manager = response_state_manager
        self.audio_output_queue = audio_output_queue

        # Not written to anywhere in this class; kept because terminate()
        # flushes them and external code may still reference them.
        self.llm_sentence_queue = Queue(maxsize=100)
        self.speech_chunk_queue = Queue(maxsize=100)

        # Parallel lists: sentence_queues[i] is filled by sentence_tasks[i].
        self.sentence_queues = []
        self.sentence_tasks = []

        # Initialised here so terminate() is safe to call even if run() or
        # speech_to_converter() never started.
        self.task_group_tasks = []
        self.ffmpeg_converter = None
        self.ffmpeg_converter_task = None

    async def prompt_to_llm(self, prompt: str, messages: list[str]):
        """Stream the LLM reply and start one TTS task per finished sentence.

        Partial (or ignored) sentences only update the preview held by the
        response state manager. Returns once the LLM stream is exhausted and
        every spawned TTS task has finished (TaskGroup semantics).

        Args:
            prompt: The prompt used to seed the ``AgentResponse``.
            messages: Conversation passed to the chat service.
        """
        chat_service = ChatService()

        async with TaskGroup() as tg:
            agent_response = AgentResponse(prompt)
            sentence_stream = chat_service.get_responses_as_sentances_async(messages)
            async for text, is_complete_sentence in sentence_stream:
                # Sentences the chat service wants ignored are treated like
                # unfinished ones: shown as preview, never spoken.
                if chat_service.ignore_sentence(text):
                    is_complete_sentence = False

                if not is_complete_sentence:
                    agent_response['llm_preview'] = text
                    self.response_state_manager.set_llm_preview(text)
                    continue

                agent_response['llm_preview'] = ''
                agent_response['llm_sentence'] = text
                agent_response['llm_sentences'].append(text)
                self.response_state_manager.add_llm_response_and_clear_llm_preview(text)
                print(f"{agent_response['llm_sentence']} id: {agent_response['llm_sentence_id']} from prompt: {agent_response['prompt']}")

                # Snapshot the response so the TTS task is unaffected by the
                # mutations made for later sentences.
                sentence_response = agent_response.make_copy()
                sentence_queue = Queue()
                # Queue is appended before its task so that
                # speech_to_converter always finds a queue for every task.
                self.sentence_queues.append(sentence_queue)
                task = tg.create_task(
                    self.llm_sentence_to_speech(sentence_response, sentence_queue))
                self.sentence_tasks.append(task)
                agent_response['llm_sentence_id'] += 1

    async def llm_sentence_to_speech(self, sentence_response, output_queue):
        """Convert one sentence to speech chunks and enqueue them.

        Each chunk yielded by the TTS service is copied onto ``output_queue``
        and a JSON id for it is registered with the response state manager.

        Args:
            sentence_response: Response snapshot for the sentence to speak.
            output_queue: Queue that receives this sentence's chunks.
        """
        tts_service = TextToSpeechService(self.voice_id)
        sentence_id = sentence_response['llm_sentence_id']

        chunk_count = 0
        async for chunk_response in tts_service.get_speech_chunks_async(sentence_response):
            await output_queue.put(chunk_response.make_copy())

            chunk_id = {
                'prompt': sentence_response['prompt'],
                'llm_sentence_id': sentence_id,
                'chunk_count': chunk_count,
            }
            self.response_state_manager.add_tts_raw_chunk_id(
                json.dumps(chunk_id), sentence_id)
            chunk_count += 1

    async def speech_to_converter(self):
        """Feed queued speech chunks to FFmpeg, preserving sentence order.

        Starts the converter, then polls forever; the coroutine only ends
        when it is cancelled (see ``terminate``).
        """
        self.ffmpeg_converter = FFMpegConverter(self.audio_output_queue)
        await self.ffmpeg_converter.start_process()
        self.ffmpeg_converter_task = asyncio.create_task(self.ffmpeg_converter.run())

        while True:
            for task, queue in zip(self.sentence_tasks, self.sentence_queues):
                # Sentence fully produced and fully consumed: move on.
                if task.done() and queue.empty():
                    continue

                while not queue.empty():
                    chunk_response = await queue.get()
                    audio_chunk_ref = chunk_response['tts_raw_chunk_ref']
                    # NOTE(review): ray.get() is a synchronous call inside a
                    # coroutine; confirm it cannot stall the event loop here.
                    audio_chunk = ray.get(audio_chunk_ref)
                    await self.ffmpeg_converter.push_chunk(audio_chunk)

                # Only the earliest unfinished sentence is serviced per pass,
                # so later sentences cannot overtake it.
                break

            await asyncio.sleep(0.01)

    async def run(self, prompt: str, messages: list[str]):
        """Run the LLM stage and the converter stage concurrently.

        Because ``speech_to_converter`` never returns on its own, this
        coroutine finishes only through cancellation or an error.

        Args:
            prompt: The prompt being answered.
            messages: Conversation passed to the chat service.
        """
        self.task_group_tasks = []
        async with TaskGroup() as tg:
            llm_task = tg.create_task(self.prompt_to_llm(prompt, messages))
            converter_task = tg.create_task(self.speech_to_converter())
            self.task_group_tasks.extend([llm_task, converter_task])

    async def terminate(self):
        """Cancel all pipeline tasks, close the converter and drain queues.

        Safe to call before ``run`` has started.
        """
        # Cancel the top-level stages and every per-sentence TTS task.
        for task in self.task_group_tasks:
            task.cancel()
        for task in self.sentence_tasks:
            task.cancel()

        # The converter is only closed if its run() task was started.
        if self.ffmpeg_converter_task:
            self.ffmpeg_converter_task.cancel()
            await self.ffmpeg_converter.close()

        # Discard whatever is still buffered.
        self._drain(self.llm_sentence_queue)
        self._drain(self.speech_chunk_queue)
        for sentence_queue in self.sentence_queues:
            self._drain(sentence_queue)

    @staticmethod
    def _drain(queue):
        """Remove and discard every item currently held in ``queue``."""
        while not queue.empty():
            queue.get_nowait()