File size: 5,384 Bytes
967aebb c72e80d 2d00549 c72e80d 2d00549 c72e80d c4bb76f 3abafc4 c72e80d 3abafc4 c72e80d 2d00549 c72e80d 3abafc4 c72e80d 3abafc4 c72e80d 2d00549 3abafc4 2d00549 3abafc4 2d00549 3abafc4 2d00549 3abafc4 2d00549 3abafc4 c72e80d 3abafc4 c72e80d 3abafc4 2d00549 c72e80d f6f039f c72e80d 3abafc4 2d00549 3abafc4 2d00549 3abafc4 c4bb76f 3abafc4 c72e80d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from typing import Dict, Any, List, Generator
import torch
import os
import logging
from s2s_pipeline import main, prepare_all_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline
import numpy as np
from queue import Queue, Empty
import threading
import base64
import uuid
class EndpointHandler:
    """Session-based request handler in front of the ``s2s_pipeline`` pipeline.

    A ``"start"`` request enqueues the input for the pipeline, spawns a
    background collector thread and returns a ``session_id``.  Subsequent
    ``"continue"`` requests poll that session and return whatever output has
    been collected since the previous poll, base64-encoded.

    NOTE(review): every collector thread reads from the single shared
    ``send_audio_chunks_queue``, so output is not tied to the session that
    produced it when several sessions overlap -- confirm callers only run
    one session at a time.
    NOTE(review): entries in ``self.sessions`` are never removed by this
    class.
    """

    # Seconds a collector waits for the next output item before it treats
    # the response as finished.
    _OUTPUT_TIMEOUT_S = 2
    # Queue items that mark the end of a response.
    _END_MARKERS = (b"END", "END")

    def __init__(self, path=""):
        """Build the pipeline from default arguments and start it.

        Args:
            path: Not used by this implementation.
        """
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = get_default_arguments(mode='none', log_level='DEBUG')

        setup_logger(self.module_kwargs.log_level)

        prepare_all_args(
            self.module_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        )

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )
        self.pipeline_manager.start()

        # Not read anywhere in this class; kept so existing references to
        # the attribute keep working.
        self.final_output_queue = Queue()
        # session_id -> {'status', 'chunks', 'last_sent_index'}
        self.sessions = {}
        # Collector threads started by _handle_start_request, so cleanup()
        # has something real to join.
        self._collector_threads = []

    def _collect_output(self, session_id):
        """Drain pipeline output into the session until the response ends.

        Runs on a background thread.  Collection stops when an ``"END"`` /
        ``b"END"`` marker is read or when no item arrives within
        ``_OUTPUT_TIMEOUT_S`` seconds; in both cases the session is marked
        ``'completed'``.  NumPy arrays are stored as their raw bytes, any
        other item is stored unchanged.

        Args:
            session_id: Key of an existing entry in ``self.sessions``.
        """
        session = self.sessions[session_id]
        output_queue = self.queues_and_events['send_audio_chunks_queue']
        while True:
            try:
                output = output_queue.get(timeout=self._OUTPUT_TIMEOUT_S)
            except Empty:
                break
            # The isinstance guard keeps ndarray items away from the ``in``
            # comparison, which is ambiguous for arrays.
            if isinstance(output, (str, bytes)) and output in self._END_MARKERS:
                break
            if isinstance(output, np.ndarray):
                output = output.tobytes()
            session['chunks'].append(output)
        # Set only after the final append, so a reader that sees
        # 'completed' can rely on 'chunks' being final.
        session['status'] = 'completed'

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch a request on its ``request_type`` field.

        Args:
            data: Request payload.  ``request_type`` is ``"start"`` (the
                default) or ``"continue"``.

        Returns:
            The response dict of the selected handler.

        Raises:
            ValueError: If ``request_type`` is neither ``"start"`` nor
                ``"continue"``.
        """
        request_type = data.get("request_type", "start")
        if request_type == "start":
            return self._handle_start_request(data)
        if request_type == "continue":
            return self._handle_continue_request(data)
        raise ValueError(f"Unsupported request type: {request_type}")

    def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Open a session, feed the input to the pipeline, start collecting.

        Args:
            data: Request payload.  ``input_type`` is ``"text"`` (the
                default) or ``"speech"``; ``inputs`` carries the prompt
                text or a bytes-like buffer that is read as int16 samples.

        Returns:
            ``{"session_id": <new id>, "status": "processing"}``.

        Raises:
            ValueError: If ``input_type`` is unsupported.
        """
        input_type = data.get("input_type", "text")
        input_data = data.get("inputs", "")

        # Validate and decode BEFORE registering the session, so a rejected
        # request cannot leave behind a session stuck in 'processing'.
        if input_type == "speech":
            queue_name = 'recv_audio_chunks_queue'
            # Going through an int16 view makes numpy reject buffers that
            # are not a whole number of samples.
            payload = np.frombuffer(input_data, dtype=np.int16).tobytes()
        elif input_type == "text":
            queue_name = 'text_prompt_queue'
            payload = input_data
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'status': 'processing',
            'chunks': [],
            'last_sent_index': 0,
        }
        self.queues_and_events[queue_name].put(payload)

        # Start output collection in a separate thread.  Finished threads
        # are dropped first so the list does not grow without bound.
        collector = threading.Thread(target=self._collect_output, args=(session_id,))
        self._collector_threads = [t for t in self._collector_threads if t.is_alive()]
        self._collector_threads.append(collector)
        collector.start()

        return {"session_id": session_id, "status": "processing"}

    def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the output collected since the previous poll.

        Args:
            data: Request payload containing ``session_id``.

        Returns:
            A dict with ``session_id``, ``status`` and ``output``, where
            ``output`` is the base64 text of the new chunks joined
            together, or ``None`` when nothing new has arrived.

        Raises:
            ValueError: If ``session_id`` is missing or unknown.
        """
        session_id = data.get("session_id")
        if not session_id or session_id not in self.sessions:
            raise ValueError("Invalid or missing session_id")

        session = self.sessions[session_id]

        # Read the status BEFORE slicing: the collector sets 'completed'
        # only after its last append, so a 'completed' seen here guarantees
        # the slice below contains every remaining chunk.
        status = session['status']
        start = session['last_sent_index']
        chunks_to_send = session['chunks'][start:]
        # Advance by what was actually taken.  Using len(session['chunks'])
        # here would skip any chunk the collector thread appended after the
        # slice was made.
        session['last_sent_index'] = start + len(chunks_to_send)

        if chunks_to_send:
            # NOTE(review): the join assumes every stored chunk is bytes;
            # _collect_output stores non-ndarray items unchanged.
            output = base64.b64encode(b''.join(chunks_to_send)).decode('utf-8')
        else:
            output = None

        return {
            "session_id": session_id,
            "status": status,
            "output": output,
        }

    def cleanup(self):
        """Stop the pipeline and wait for running collector threads."""
        # Stop the pipeline
        self.pipeline_manager.stop()

        # Wake every collector that is still running with its own end
        # marker, then wait for it.  Nothing is queued when no collector is
        # alive, so no stale marker is left behind for a later session.
        running = [t for t in self._collector_threads if t.is_alive()]
        output_queue = self.queues_and_events['send_audio_chunks_queue']
        for _ in running:
            output_queue.put(b"END")
        for collector in running:
            collector.join()
        self._collector_threads = []