File size: 4,770 Bytes
a461509
 
c259114
a461509
 
 
 
 
 
c259114
 
a461509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c259114
a461509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from typing import Literal
import gradio as gr
from gradio_webrtc import WebRTC, StreamHandler, AdditionalOutputs, get_twilio_turn_credentials
from numpy import ndarray
import sphn
import websockets.sync.client
import numpy as np


rtc_configuration = get_twilio_turn_credentials()

class MoshiHandler(StreamHandler):

    def __init__(self,
                 url: str,
                 expected_layout: Literal['mono', 'stereo'] = "mono",
                 output_sample_rate: int = 24000,
                 output_frame_size: int = 480) -> None:
        self.url = url
        proto, without_proto = self.url.split('://', 1)
        if proto in ['ws', 'http']:
            proto = "ws"
        elif proto in ['wss', 'https']:
            proto = "wss"
        
        self._generator = None
        self.output_chunk_size = 1920
        self.ws = None
        self.ws_url = f"{proto}://{without_proto}/api/chat"
        self.stream_reader = sphn.OpusStreamReader(output_sample_rate)
        self.stream_writer = sphn.OpusStreamWriter(output_sample_rate)
        self.all_output_data = None
        super().__init__(expected_layout, output_sample_rate, output_frame_size,
                         input_sample_rate=24000)
    
    def receive(self, frame: tuple[int, ndarray]) -> None:
        if not self.ws:
            self.ws = websockets.sync.client.connect(self.ws_url)
        _, array = frame
        array = array.squeeze().astype(np.float32) / 32768.0
        self.stream_writer.append_pcm(array)
        bytes = b"\x01" + self.stream_writer.read_bytes()
        self.ws.send(bytes)

    def generator(self):
        for message in self.ws:
            if len(message) == 0:
                yield None
            kind = message[0]
            if kind == 1:
                payload = message[1:]
                self.stream_reader.append_bytes(payload)
                pcm = self.stream_reader.read_pcm()
                if self.all_output_data is None:
                    self.all_output_data = pcm
                else:
                    self.all_output_data = np.concatenate((self.all_output_data, pcm))
                while self.all_output_data.shape[-1] >= self.output_chunk_size:
                    yield (self.output_sample_rate, self.all_output_data[: self.output_chunk_size].reshape(1, -1))
                    self.all_output_data = np.array(self.all_output_data[self.output_chunk_size :])
            elif kind == 2:
                payload = message[1:]
                yield AdditionalOutputs(payload.decode())

    
    def emit(self) -> tuple[int, ndarray] | None:
        if not self.ws:
            return
        if not self._generator:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self.reset()
    
    def reset(self) -> None:
        self._generator = None
        self.all_output_data = None
    
    def copy(self) -> StreamHandler:
        return MoshiHandler(self.url,
                            self.expected_layout,
                            self.output_sample_rate, self.output_frame_size)

    def shutdown(self) -> None:
        if self.ws:
            self.ws.close()



with gr.Blocks() as demo:
        gr.HTML(
        """
        <div style='text-align: center'>
            <h1>
                Talk To Moshi (Powered by WebRTC ⚡️)
            </h1>
            <p>
                Each conversation is limited to 90 seconds. Once the time limit is up you can rejoin the conversation.
            </p>
        </div>
        """
        )
        response = gr.State(value="")
        chatbot = gr.Chatbot(type="messages", value=[])
        webrtc = WebRTC(label="Conversation", modality="audio", mode="send-receive", rtc_configuration=rtc_configuration)
        webrtc.stream(MoshiHandler("https://freddyaboulton-moshi-server.hf.space"),
                      inputs=[webrtc, chatbot], outputs=[webrtc], time_limit=90)
        
        def on_text(state, response):
            print("response", response)
            return state + response
        
        def add_text(chat_history, response):
            if len(chat_history) == 0:
                chat_history.append({"role": "assistant", "content": response})
            else:
                chat_history[-1]["content"] = response
            return chat_history
        
        webrtc.on_additional_outputs(on_text,
                                     inputs=[response], outputs=response,
                                     queue=True,
                                     show_progress="hidden")
        response.change(add_text, inputs=[chatbot, response], outputs=chatbot,
                        queue=True, show_progress="hidden")

demo.launch()