from typing import Dict, Any
from queue import Queue, Empty
import threading
import base64
import uuid

import numpy as np

from s2s_pipeline import (
    prepare_all_args,
    get_default_arguments,
    setup_logger,
    initialize_queues_and_events,
    build_pipeline,
)

class EndpointHandler:
    def __init__(self, path=""):
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = get_default_arguments(mode='none', log_level='DEBUG')
        setup_logger(self.module_kwargs.log_level)

        prepare_all_args(
            self.module_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        )

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )

        self.pipeline_manager.start()

        # Add a new queue for collecting the final output
        self.final_output_queue = Queue()
        self.sessions = {}  # Store session information

    def _collect_output(self, session_id):
        # Drain generated audio chunks from the pipeline's output queue into
        # this session's buffer until an END sentinel arrives or the queue
        # stays idle for 2 seconds.
        while True:
            try:
                output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=2)
                if isinstance(output, (str, bytes)) and output in (b"END", "END"):
                    self.sessions[session_id]['status'] = 'completed'
                    break
                elif isinstance(output, np.ndarray):
                    self.sessions[session_id]['chunks'].append(output.tobytes())
                else:
                    self.sessions[session_id]['chunks'].append(output)
            except Empty:
                self.sessions[session_id]['status'] = 'completed'
                break

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Dispatch on request_type: "start" begins a new session, "continue"
        # polls an existing session for newly generated audio.
        request_type = data.get("request_type", "start")
        
        if request_type == "start":
            return self._handle_start_request(data)
        elif request_type == "continue":
            return self._handle_continue_request(data)
        else:
            raise ValueError(f"Unsupported request type: {request_type}")

    def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Create a new session, feed the input into the pipeline, and start a
        # background thread that collects the generated audio chunks.
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'status': 'processing',
            'chunks': [],
            'last_sent_index': 0
        }

        input_type = data.get("input_type", "text")
        input_data = data.get("inputs", "")

        if input_type == "speech":
            audio_array = np.frombuffer(input_data, dtype=np.int16)
            self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
        elif input_type == "text":
            self.queues_and_events['text_prompt_queue'].put(input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        # Start output collection in a separate thread and keep a handle on it
        # so cleanup() can join it later.
        collector_thread = threading.Thread(target=self._collect_output, args=(session_id,))
        collector_thread.start()
        self.sessions[session_id]['thread'] = collector_thread

        return {"session_id": session_id, "status": "processing"}

    def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Return any audio chunks produced since the last "continue" call,
        # base64-encoded, together with the current session status.
        session_id = data.get("session_id")
        if not session_id or session_id not in self.sessions:
            raise ValueError("Invalid or missing session_id")

        session = self.sessions[session_id]
        chunks_to_send = session['chunks'][session['last_sent_index']:]
        session['last_sent_index'] = len(session['chunks'])

        if chunks_to_send:
            combined_audio = b''.join(chunks_to_send)
            base64_audio = base64.b64encode(combined_audio).decode('utf-8')
            return {
                "session_id": session_id,
                "status": session['status'],
                "output": base64_audio
            }
        else:
            return {
                "session_id": session_id,
                "status": session['status'],
                "output": None
            }

    def cleanup(self):
        # Unblock any collector threads that are still waiting on the output
        # queue, then wait for them to finish before stopping the pipeline.
        self.queues_and_events['send_audio_chunks_queue'].put(b"END")
        for session in self.sessions.values():
            thread = session.get('thread')
            if thread is not None and thread.is_alive():
                thread.join()

        # Stop the pipeline
        self.pipeline_manager.stop()
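
# Minimal usage sketch (illustrative, not part of the deployed handler): shows
# the expected request/response flow against EndpointHandler for a text prompt.
# The prompt text and polling interval are assumptions for local testing only.
if __name__ == "__main__":
    import time

    handler = EndpointHandler()

    # "start" kicks off processing and returns a session_id to poll with.
    start_response = handler(
        {"request_type": "start", "input_type": "text", "inputs": "Hello!"}
    )
    session_id = start_response["session_id"]

    # "continue" returns newly generated audio as base64 until the session completes.
    audio = b""
    while True:
        response = handler({"request_type": "continue", "session_id": session_id})
        if response["output"]:
            audio += base64.b64decode(response["output"])
        if response["status"] == "completed":
            break
        time.sleep(0.5)

    handler.cleanup()
    print(f"Received {len(audio)} bytes of generated audio")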