freddyaboulton (HF staff) committed
Commit e9e89e1
1 Parent(s): ce780f3
Files changed (4)
  1. .gitignore +1 -0
  2. app.py +207 -0
  3. openai-logo.svg +1 -0
  4. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .env
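Note: app.py loads this ignored .env file at startup via python-dotenv, which is why it is kept out of the repo. A minimal sketch of what it might contain (only OPENAI_API_KEY is referenced directly in app.py; the Twilio variable names are an assumption about what get_twilio_turn_credentials reads by default):

# read by load_dotenv(); used to prefill the API key textbox in app.py
OPENAI_API_KEY=<your OpenAI API key>
# assumed defaults for the Twilio TURN credential helper; verify against the gradio_webrtc docs
TWILIO_ACCOUNT_SID=<your Twilio account SID>
TWILIO_AUTH_TOKEN=<your Twilio auth token>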
app.py ADDED
@@ -0,0 +1,207 @@
+ import asyncio
+ import base64
+ import os
+ import time
+ from threading import Event, Thread
+
+ import gradio as gr
+ import numpy as np
+ import openai
+ from dotenv import load_dotenv
+ from gradio_webrtc import (
+     AdditionalOutputs,
+     StreamHandler,
+     WebRTC,
+     get_twilio_turn_credentials,
+ )
+ from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
+ from pydub import AudioSegment
+
+ load_dotenv()
+
+ SAMPLE_RATE = 24000
+
+
+ def encode_audio(sample_rate, data):
+     segment = AudioSegment(
+         data.tobytes(),
+         frame_rate=sample_rate,
+         sample_width=data.dtype.itemsize,
+         channels=1,
+     )
+     pcm_audio = (
+         segment.set_frame_rate(SAMPLE_RATE).set_channels(1).set_sample_width(2).raw_data
+     )
+     return base64.b64encode(pcm_audio).decode("utf-8")
+
+
+ class OpenAIHandler(StreamHandler):
+     def __init__(
+         self,
+         expected_layout="mono",
+         output_sample_rate=SAMPLE_RATE,
+         output_frame_size=480,
+     ) -> None:
+         super().__init__(
+             expected_layout,
+             output_sample_rate,
+             output_frame_size,
+             input_sample_rate=SAMPLE_RATE,
+         )
+         self.connection = None
+         self.all_output_data = None
+         self.args_set = Event()
+         self.quit = Event()
+         self.connected = Event()
+         self.thread = None
+         self._generator = None
+
+     def copy(self):
+         return OpenAIHandler(
+             expected_layout=self.expected_layout,
+             output_sample_rate=self.output_sample_rate,
+             output_frame_size=self.output_frame_size,
+         )
+
+     def _initialize_connection(self, api_key: str):
+         """Connect to realtime API. Run forever in separate thread to keep connection open."""
+         self.client = openai.Client(api_key=api_key)
+         with self.client.beta.realtime.connect(
+             model="gpt-4o-realtime-preview-2024-10-01"
+         ) as conn:
+             conn.session.update(session={"turn_detection": {"type": "server_vad"}})
+             self.connection = conn
+             self.connected.set()
+             while not self.quit.is_set():
+                 time.sleep(0.25)
+
+     async def fetch_args(
+         self,
+     ):
+         if self.channel:
+             self.channel.send("tick")
+
+     def set_args(self, args):
+         super().set_args(args)
+         self.args_set.set()
+
+     def receive(self, frame: tuple[int, np.ndarray]) -> None:
+         if not self.channel:
+             return
+         if not self.connection:
+             asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
+             self.args_set.wait()
+             self.thread = Thread(
+                 target=self._initialize_connection, args=(self.latest_args[-1],)
+             )
+             self.thread.start()
+             self.connected.wait()
+         try:
+             assert self.connection, "Connection not initialized"
+             sample_rate, array = frame
+             array = array.squeeze()
+             audio_message = encode_audio(sample_rate, array)
+             self.connection.input_audio_buffer.append(audio=audio_message)
+         except Exception as e:
+             # print traceback
+             print(f"Error in receive: {str(e)}")
+             import traceback
+
+             traceback.print_exc()
+
+     def generator(self):
+         while True:
+             if not self.connection:
+                 yield None
+                 continue
+             for event in self.connection:
+                 if event.type == "response.audio_transcript.done":
+                     yield AdditionalOutputs(event)
+                 if event.type == "response.audio.delta":
+                     yield (
+                         self.output_sample_rate,
+                         np.frombuffer(
+                             base64.b64decode(event.delta), dtype=np.int16
+                         ).reshape(1, -1),
+                     )
+
+     def emit(self) -> tuple[int, np.ndarray] | None:
+         if not self.connection:
+             return None
+         if not self._generator:
+             self._generator = self.generator()
+         try:
+             return next(self._generator)
+         except StopIteration:
+             self._generator = self.generator()
+             return None
+
+     def shutdown(self) -> None:
+         if self.connection:
+             self.connection.close()
+             self.quit.set()
+             if self.thread:
+                 self.thread.join(timeout=5)
+
+
+ def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
+     chatbot.append({"role": "assistant", "content": response.transcript})
+     return chatbot
+
+
+ with gr.Blocks() as demo:
+     gr.HTML("""
+     <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
+         <div style="background-color: var(--block-background-fill); border-radius: 8px">
+             <img src="/gradio_api/file=openai-logo.svg" style="width: 100px; height: 100px;">
+         </div>
+         <div>
+             <h1>OpenAI Realtime Voice Chat</h1>
+             <p>Speak with OpenAI's latest model using the real-time audio streaming API.</p>
+             <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
+             <p>Get an API key from <a href="https://platform.openai.com/">OpenAI</a>.</p>
+         </div>
+     </div>
+     """)
+
+     with gr.Row(visible=True) as api_key_row:
+         api_key = gr.Textbox(
+             label="OpenAI API Key",
+             placeholder="Enter your OpenAI API Key",
+             value=os.getenv("OPENAI_API_KEY", ""),
+             type="password",
+         )
+     with gr.Row(visible=False) as row:
+         with gr.Column(scale=1):
+             webrtc = WebRTC(
+                 label="Conversation",
+                 modality="audio",
+                 mode="send-receive",
+                 rtc_configuration=get_twilio_turn_credentials(),
+                 icon="openai-logo.svg",
+             )
+         with gr.Column(scale=5):
+             chatbot = gr.Chatbot(label="Conversation", value=[], type="messages")
+             webrtc.stream(
+                 OpenAIHandler(),
+                 inputs=[webrtc, api_key],
+                 outputs=[webrtc],
+                 time_limit=90,
+                 concurrency_limit=2,
+             )
+             webrtc.on_additional_outputs(
+                 update_chatbot,
+                 inputs=[chatbot],
+                 outputs=[chatbot],
+                 show_progress="hidden",
+                 queue=True,
+             )
+     api_key.submit(
+         lambda: (gr.update(visible=False), gr.update(visible=True)),
+         None,
+         [api_key_row, row],
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch(allowed_paths=["openai-logo.svg"])
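A quick way to sanity-check the resampling path in encode_audio outside the Space (a hypothetical snippet that mirrors the function rather than importing app.py, since importing the module would also build the Gradio UI and call get_twilio_turn_credentials):

import base64

import numpy as np
from pydub import AudioSegment

# one second of mono int16 silence captured at 48 kHz
data = np.zeros(48000, dtype=np.int16)
segment = AudioSegment(
    data.tobytes(),
    frame_rate=48000,
    sample_width=data.dtype.itemsize,
    channels=1,
)
# same chain as encode_audio: resample to 24 kHz, mono, 16-bit PCM
pcm = segment.set_frame_rate(24000).set_channels(1).set_sample_width(2).raw_data
assert len(pcm) == 24000 * 2  # 24000 frames x 2 bytes per sample
payload = base64.b64encode(pcm).decode("utf-8")  # same shape of payload that receive() appends to input_audio_buffer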
openai-logo.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio_webrtc[vad]
+ openai
+ python-dotenv
+ twilio
+ numba==0.60.0
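To run the demo locally, the usual flow would be something like the following (a sketch; it assumes Python 3.10+ for the tuple[int, np.ndarray] | None annotations and an OPENAI_API_KEY either in .env or typed into the textbox):

pip install -r requirements.txt
python app.py  # demo.launch() prints the local URL; openai-logo.svg is served via allowed_paths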