freddyaboulton HF staff committed on
Commit
01a49c3
·
1 Parent(s): f20d058
Files changed (2) hide show
  1. app.py +33 -26
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import base64
2
  import io
 
3
  import tempfile
4
  import time
5
  import traceback
6
- from dataclasses import dataclass, field
7
  from queue import Queue
8
- from threading import Thread, Event
9
 
10
  import gradio as gr
11
  import librosa
@@ -14,16 +15,13 @@ import requests
14
  from gradio_webrtc import StreamHandler, WebRTC
15
  from huggingface_hub import snapshot_download
16
  from pydub import AudioSegment
17
- import librosa
18
- from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
19
- import tempfile
20
 
21
  # from server import serve
22
  from utils.vad import VadOptions, collect_chunks, get_speech_timestamps
23
 
24
-
25
- from server import serve
26
-
27
  repo_id = "gpt-omni/mini-omni"
28
  snapshot_download(repo_id, local_dir="./checkpoint", revision="main")
29
 
@@ -36,8 +34,20 @@ thread.start()
36
 
37
  API_URL = "http://0.0.0.0:60808/chat"
38
 
 
 
 
 
 
39
 
40
- #API_URL = "https://freddyaboulton-omni-backend.hf.space/chat"
 
 
 
 
 
 
 
41
 
42
  # recording parameters
43
  IN_CHANNELS = 1
@@ -89,7 +99,8 @@ def warm_up():
89
  print(f"warm up done, time_cost: {tcost:.3f} s")
90
 
91
 
92
- warm_up()
 
93
 
94
  @dataclass
95
  class AppState:
@@ -97,27 +108,26 @@ class AppState:
97
  sampling_rate: int = 0
98
  pause_detected: bool = False
99
  started_talking: bool = False
100
- responding: bool = False
101
  stopped: bool = False
102
  buffer: np.ndarray | None = None
103
 
104
 
105
-
106
  def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
107
  """Take in the stream, determine if a pause happened"""
108
  duration = len(audio) / sampling_rate
109
-
110
  dur_vad, _, _ = run_vad(audio, sampling_rate)
111
 
112
  if duration >= 0.60:
113
  if dur_vad > 0.2 and not state.started_talking:
114
  print("started talking")
115
- state.started_talking = True
116
  if state.started_talking:
117
  if state.stream is None:
118
  state.stream = audio
119
  else:
120
- state.stream = np.concatenate((state.stream, audio))
121
  state.buffer = None
122
  if dur_vad < 0.1 and state.started_talking:
123
  segment = AudioSegment(
@@ -135,7 +145,6 @@ def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> b
135
 
136
 
137
  def speaking(audio_bytes: str):
138
-
139
  base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
140
  files = {"audio": base64_encoded}
141
  byte_buffer = b""
@@ -167,7 +176,6 @@ def speaking(audio_bytes: str):
167
  raise gr.Error(f"Error during audio streaming: {e}")
168
 
169
 
170
-
171
  def process_audio(audio: tuple, state: AppState) -> None:
172
  frame_rate, array = audio
173
  array = np.squeeze(array)
@@ -185,7 +193,7 @@ def process_audio(audio: tuple, state: AppState) -> None:
185
  def response(state: AppState):
186
  if not state.pause_detected and not state.started_talking:
187
  return None
188
-
189
  audio_buffer = io.BytesIO()
190
  segment = AudioSegment(
191
  state.stream.tobytes(),
@@ -194,14 +202,16 @@ def response(state: AppState):
194
  channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
195
  )
196
  segment.export(audio_buffer, format="wav")
197
-
198
  for numpy_array in speaking(audio_buffer.getvalue()):
199
- yield (OUT_RATE, numpy_array, "mono")
200
 
201
 
202
  class OmniHandler(StreamHandler):
203
  def __init__(self) -> None:
204
- super().__init__(expected_layout="mono", output_sample_rate=OUT_RATE, output_frame_size=480)
 
 
205
  self.chunk_queue = Queue()
206
  self.state = AppState()
207
  self.generator = None
@@ -213,7 +223,7 @@ class OmniHandler(StreamHandler):
213
  process_audio(frame, self.state)
214
  if self.state.pause_detected:
215
  self.chunk_queue.put(True)
216
-
217
  def reset(self):
218
  self.generator = None
219
  self.state = AppState()
@@ -225,10 +235,9 @@ class OmniHandler(StreamHandler):
225
  self.state.responding = True
226
  self.generator = response(self.state)
227
  try:
228
- return next(self.generator)
229
  except StopIteration:
230
  self.reset()
231
-
232
 
233
 
234
  with gr.Blocks() as demo:
@@ -250,6 +259,4 @@ with gr.Blocks() as demo:
250
  audio.stream(fn=OmniHandler(), inputs=[audio], outputs=[audio], time_limit=300)
251
 
252
 
253
-
254
-
255
  demo.launch()
 
1
  import base64
2
  import io
3
+ import os
4
  import tempfile
5
  import time
6
  import traceback
7
+ from dataclasses import dataclass
8
  from queue import Queue
9
+ from threading import Thread
10
 
11
  import gradio as gr
12
  import librosa
 
15
  from gradio_webrtc import StreamHandler, WebRTC
16
  from huggingface_hub import snapshot_download
17
  from pydub import AudioSegment
18
+ from twilio.rest import Client
19
+
20
+ from server import serve
21
 
22
  # from server import serve
23
  from utils.vad import VadOptions, collect_chunks, get_speech_timestamps
24
 
 
 
 
25
  repo_id = "gpt-omni/mini-omni"
26
  snapshot_download(repo_id, local_dir="./checkpoint", revision="main")
27
 
 
34
 
35
  API_URL = "http://0.0.0.0:60808/chat"
36
 
37
+ account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
38
+ auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
39
+
40
+ if account_sid and auth_token:
41
+ client = Client(account_sid, auth_token)
42
 
43
+ token = client.tokens.create()
44
+
45
+ rtc_configuration = {
46
+ "iceServers": token.ice_servers,
47
+ "iceTransportPolicy": "relay",
48
+ }
49
+ else:
50
+ rtc_configuration = None
51
 
52
  # recording parameters
53
  IN_CHANNELS = 1
 
99
  print(f"warm up done, time_cost: {tcost:.3f} s")
100
 
101
 
102
+ # warm_up()
103
+
104
 
105
  @dataclass
106
  class AppState:
 
108
  sampling_rate: int = 0
109
  pause_detected: bool = False
110
  started_talking: bool = False
111
+ responding: bool = False
112
  stopped: bool = False
113
  buffer: np.ndarray | None = None
114
 
115
 
 
116
  def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
117
  """Take in the stream, determine if a pause happened"""
118
  duration = len(audio) / sampling_rate
119
+
120
  dur_vad, _, _ = run_vad(audio, sampling_rate)
121
 
122
  if duration >= 0.60:
123
  if dur_vad > 0.2 and not state.started_talking:
124
  print("started talking")
125
+ state.started_talking = True
126
  if state.started_talking:
127
  if state.stream is None:
128
  state.stream = audio
129
  else:
130
+ state.stream = np.concatenate((state.stream, audio))
131
  state.buffer = None
132
  if dur_vad < 0.1 and state.started_talking:
133
  segment = AudioSegment(
 
145
 
146
 
147
  def speaking(audio_bytes: str):
 
148
  base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
149
  files = {"audio": base64_encoded}
150
  byte_buffer = b""
 
176
  raise gr.Error(f"Error during audio streaming: {e}")
177
 
178
 
 
179
  def process_audio(audio: tuple, state: AppState) -> None:
180
  frame_rate, array = audio
181
  array = np.squeeze(array)
 
193
  def response(state: AppState):
194
  if not state.pause_detected and not state.started_talking:
195
  return None
196
+
197
  audio_buffer = io.BytesIO()
198
  segment = AudioSegment(
199
  state.stream.tobytes(),
 
202
  channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
203
  )
204
  segment.export(audio_buffer, format="wav")
205
+
206
  for numpy_array in speaking(audio_buffer.getvalue()):
207
+ yield (OUT_RATE, numpy_array, "mono")
208
 
209
 
210
  class OmniHandler(StreamHandler):
211
  def __init__(self) -> None:
212
+ super().__init__(
213
+ expected_layout="mono", output_sample_rate=OUT_RATE, output_frame_size=480
214
+ )
215
  self.chunk_queue = Queue()
216
  self.state = AppState()
217
  self.generator = None
 
223
  process_audio(frame, self.state)
224
  if self.state.pause_detected:
225
  self.chunk_queue.put(True)
226
+
227
  def reset(self):
228
  self.generator = None
229
  self.state = AppState()
 
235
  self.state.responding = True
236
  self.generator = response(self.state)
237
  try:
238
+ return next(self.generator)
239
  except StopIteration:
240
  self.reset()
 
241
 
242
 
243
  with gr.Blocks() as demo:
 
259
  audio.stream(fn=OmniHandler(), inputs=[audio], outputs=[audio], time_limit=300)
260
 
261
 
 
 
262
  demo.launch()
requirements.txt CHANGED
@@ -12,4 +12,5 @@ fastapi==0.112.4
12
  librosa==0.10.2.post1
13
  flask==3.0.3
14
  fire
15
- https://gradio-builds.s3.us-east-1.amazonaws.com/webrtc/08/gradio_webrtc-0.0.5-py3-none-any.whl
 
 
12
  librosa==0.10.2.post1
13
  flask==3.0.3
14
  fire
15
+ https://gradio-builds.s3.us-east-1.amazonaws.com/webrtc/08/gradio_webrtc-0.0.5-py3-none-any.whl
16
+ twilio