project_charles / respond_to_prompt_async.py
import asyncio
import json
from asyncio import Queue, TaskGroup

import ray

from agent_response import AgentResponse
from chat_service import ChatService
from ffmpeg_converter import FFMpegConverter
# from local_speaker_service import LocalSpeakerService
from response_state_manager import ResponseStateManager
from text_to_speech_service import TextToSpeechService


class RespondToPromptAsync:
    """Streams an LLM reply sentence by sentence, converts each sentence to
    speech, and pushes the raw audio chunks through an FFmpeg converter onto
    the audio output queue."""

    def __init__(
            self,
            response_state_manager: ResponseStateManager,
            audio_output_queue):
        voice_id = "2OviOUQc1JsQRQgNkVBj"
        self.llm_sentence_queue = Queue(maxsize=100)
        self.speech_chunk_queue = Queue(maxsize=100)
        self.voice_id = voice_id
        self.audio_output_queue = audio_output_queue
        self.response_state_manager = response_state_manager
        self.sentence_queues = []
        self.sentence_tasks = []
        # Safe defaults so terminate() can be called even if run() never started.
        self.task_group_tasks = []
        self.ffmpeg_converter = None
        self.ffmpeg_converter_task = None
        # self.ffmpeg_converter = FFMpegConverter.remote(audio_output_queue)

    async def prompt_to_llm(self, prompt: str, messages: list[str]):
        chat_service = ChatService()
        async with TaskGroup() as tg:
            agent_response = AgentResponse(prompt)
            async for text, is_complete_sentence in chat_service.get_responses_as_sentances_async(messages):
                if chat_service.ignore_sentence(text):
                    is_complete_sentence = False
                if not is_complete_sentence:
                    # Partial sentence: surface it as a preview only.
                    agent_response['llm_preview'] = text
                    self.response_state_manager.set_llm_preview(text)
                    continue
                agent_response['llm_preview'] = ''
                agent_response['llm_sentence'] = text
                agent_response['llm_sentences'].append(text)
                self.response_state_manager.add_llm_response_and_clear_llm_preview(text)
                print(f"{agent_response['llm_sentence']} id: {agent_response['llm_sentence_id']} from prompt: {agent_response['prompt']}")
                # Fan out: each completed sentence gets its own queue and TTS task.
                sentence_response = agent_response.make_copy()
                new_queue = Queue()
                self.sentence_queues.append(new_queue)
                task = tg.create_task(self.llm_sentence_to_speech(sentence_response, new_queue))
                self.sentence_tasks.append(task)
                agent_response['llm_sentence_id'] += 1

    async def llm_sentence_to_speech(self, sentence_response, output_queue):
        tts_service = TextToSpeechService(self.voice_id)
        chunk_count = 0
        async for chunk_response in tts_service.get_speech_chunks_async(sentence_response):
            chunk_response = chunk_response.make_copy()
            # await self.output_queue.put_async(chunk_response)
            await output_queue.put(chunk_response)
            # Record this chunk's id with the state manager.
            chunk_id_record = {
                'prompt': sentence_response['prompt'],
                'llm_sentence_id': sentence_response['llm_sentence_id'],
                'chunk_count': chunk_count,
            }
            chunk_id_json = json.dumps(chunk_id_record)
            self.response_state_manager.add_tts_raw_chunk_id(chunk_id_json, sentence_response['llm_sentence_id'])
            chunk_count += 1
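
    # speech_to_converter() drains the per-sentence queues strictly in the
    # order the sentences were produced: each pass over sentence_tasks stops
    # at the first sentence that is still generating audio, so chunks are
    # never pushed to the converter out of order.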
    async def speech_to_converter(self):
        self.ffmpeg_converter = FFMpegConverter(self.audio_output_queue)
        await self.ffmpeg_converter.start_process()
        self.ffmpeg_converter_task = asyncio.create_task(self.ffmpeg_converter.run())
        while True:
            for i, task in enumerate(self.sentence_tasks):
                queue = self.sentence_queues[i]
                # Skip this task/queue pair if the task completed and its queue is drained.
                if task.done() and queue.empty():
                    continue
                while not queue.empty():
                    chunk_response = await queue.get()
                    audio_chunk_ref = chunk_response['tts_raw_chunk_ref']
                    # Resolve the Ray object ref to the raw audio bytes.
                    audio_chunk = ray.get(audio_chunk_ref)
                    await self.ffmpeg_converter.push_chunk(audio_chunk)
                # Stay on the earliest unfinished sentence; later queues wait their turn.
                break
            await asyncio.sleep(0.01)

    async def run(self, prompt: str, messages: list[str]):
        self.task_group_tasks = []
        async with TaskGroup() as tg:  # Use asyncio's built-in TaskGroup
            t1 = tg.create_task(self.prompt_to_llm(prompt, messages))
            t2 = tg.create_task(self.speech_to_converter())
            self.task_group_tasks.extend([t1, t2])

    async def terminate(self):
        # Cancel tasks
        if self.task_group_tasks:
            for task in self.task_group_tasks:
                task.cancel()
        for task in self.sentence_tasks:
            task.cancel()
        # Close FFmpeg converter actor
        if self.ffmpeg_converter_task:
            self.ffmpeg_converter_task.cancel()
            await self.ffmpeg_converter.close()
            # ray.kill(self.ffmpeg_converter)
        # Flush all queues
        # TODO re-enable to interrupt when user speaks
        # while not self.audio_output_queue.empty():
        #     await self.audio_output_queue.get_async()
        #     # await self.audio_output_queue.get_async(block=False)
        while not self.llm_sentence_queue.empty():
            self.llm_sentence_queue.get_nowait()
        while not self.speech_chunk_queue.empty():
            self.speech_chunk_queue.get_nowait()
        for sentence_queue in self.sentence_queues:
            while not sentence_queue.empty():
                sentence_queue.get_nowait()
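

# Minimal driver sketch (illustrative only, not used elsewhere in the module).
# It assumes audio_output_queue can be a plain asyncio.Queue consumed by the
# app's audio player, that ResponseStateManager() takes no arguments, and that
# ChatService accepts a simple list of message strings; those details depend
# on the real services in this repo and may differ.
async def _example():
    response_state_manager = ResponseStateManager()
    audio_output_queue = asyncio.Queue()
    responder = RespondToPromptAsync(response_state_manager, audio_output_queue)
    prompt = "Hello, Charles."
    messages = [prompt]  # hypothetical message format; see ChatService for the real one
    run_task = asyncio.create_task(responder.run(prompt, messages))
    await asyncio.sleep(5.0)     # stand-in for "the user interrupted"
    await responder.terminate()  # cancel LLM/TTS work and flush queues
    run_task.cancel()


if __name__ == "__main__":
    asyncio.run(_example())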