Spaces:
Running on CPU Upgrade

akhaliq HF staff committed on
Commit
fbaae9e
·
1 Parent(s): 71c8124

add gemini voice

Browse files
Files changed (3) hide show
  1. app.py +4 -1
  2. app_gemini_voice.py +209 -0
  3. requirements.txt +5 -0
app.py CHANGED
@@ -26,9 +26,12 @@ from app_together import demo as demo_together
26
  from app_xai import demo as demo_grok
27
  from app_showui import demo as demo_showui
28
  from app_omini import demo as demo_omini
 
 
29
 
30
  # Create mapping of providers to their demos
31
  PROVIDERS = {
 
32
  "Gemini": demo_gemini,
33
  "Grok": demo_grok,
34
  "Cohere": demo_cohere,
@@ -58,7 +61,7 @@ PROVIDERS = {
58
 
59
  demo = get_app(
60
  models=list(PROVIDERS.keys()),
61
- default_model="Gemini",
62
  src=PROVIDERS,
63
  dropdown_label="Select Provider"
64
  )
 
26
  from app_xai import demo as demo_grok
27
  from app_showui import demo as demo_showui
28
  from app_omini import demo as demo_omini
29
+ from app_gemini_voice import demo as demo_gemini_voice
30
+
31
 
32
  # Create mapping of providers to their demos
33
  PROVIDERS = {
34
+ "Gemini Voice": demo_gemini_voice(),
35
  "Gemini": demo_gemini,
36
  "Grok": demo_grok,
37
  "Cohere": demo_cohere,
 
61
 
62
  demo = get_app(
63
  models=list(PROVIDERS.keys()),
64
+ default_model="Gemini Voice",
65
  src=PROVIDERS,
66
  dropdown_label="Select Provider"
67
  )
app_gemini_voice.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_webrtc import WebRTC, StreamHandler, get_twilio_turn_credentials
3
+ import websockets.sync.client
4
+ import numpy as np
5
+ import json
6
+ import base64
7
+ import os
8
+ from dotenv import load_dotenv
9
+
10
class GeminiConfig:
    """Connection settings for Gemini's bidirectional (live) WebSocket API.

    Loads `.env`, resolves the API key, and precomputes the `wss://` endpoint
    URL for the BidiGenerateContent service.
    """

    def __init__(self):
        load_dotenv()
        self.api_key = self._get_api_key()
        self.host = 'generativelanguage.googleapis.com'
        self.model = 'models/gemini-2.0-flash-exp'
        # v1alpha bidirectional streaming endpoint, authenticated via query key.
        service_path = 'ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent'
        self.ws_url = f'wss://{self.host}/{service_path}?key={self.api_key}'

    def _get_api_key(self):
        """Return GOOGLE_API_KEY from the environment.

        Raises:
            ValueError: if the variable is unset or empty.
        """
        key = os.getenv('GOOGLE_API_KEY')
        if key:
            return key
        raise ValueError("GOOGLE_API_KEY not found in environment variables. Please set it in your .env file.")
23
+
24
class AudioProcessor:
    """Converts between raw PCM numpy buffers and Gemini's base64 wire format."""

    @staticmethod
    def encode_audio(data, sample_rate):
        """Wrap a PCM numpy buffer as a `realtimeInput` media-chunk message.

        Args:
            data: numpy array of PCM samples (serialized via ``tobytes()``).
            sample_rate: sample rate in Hz, recorded in the chunk's MIME type.

        Returns:
            dict ready to be JSON-serialized and sent over the websocket.
        """
        payload = base64.b64encode(data.tobytes()).decode('UTF-8')
        chunk = {
            'mimeType': f'audio/pcm;rate={sample_rate}',
            'data': payload,
        }
        return {'realtimeInput': {'mediaChunks': [chunk]}}

    @staticmethod
    def process_audio_response(data):
        """Decode a base64 audio payload into an int16 numpy sample array."""
        raw = base64.b64decode(data)
        return np.frombuffer(raw, dtype=np.int16)
41
+
42
class GeminiHandler(StreamHandler):
    """gradio_webrtc StreamHandler bridging browser audio and Gemini's
    bidirectional WebSocket API.

    Flow: ``receive()`` pushes incoming PCM frames to the socket (opening it
    lazily on the first frame); ``emit()`` pulls decoded PCM frames back out
    through the ``generator()`` / ``_process_server_content()`` pipeline.
    """

    def __init__(self,
                 expected_layout="mono",
                 output_sample_rate=24000,
                 output_frame_size=480) -> None:
        # Input and output both configured at 24 kHz; the websocket is NOT
        # opened here — it is created on the first received frame.
        super().__init__(expected_layout, output_sample_rate, output_frame_size,
                         input_sample_rate=24000)
        self.config = GeminiConfig()
        self.ws = None                   # websocket connection, None until first frame
        self.all_output_data = None      # rolling int16 buffer of audio awaiting emit
        self.audio_processor = AudioProcessor()

    def copy(self):
        """Return a fresh handler with the same audio settings (per-connection clone)."""
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size
        )

    def _initialize_websocket(self):
        """Open the Gemini websocket and perform the model-setup handshake.

        On any failure, logs and leaves ``self.ws`` as None (callers treat
        a None socket as "not connected").
        """
        try:
            self.ws = websockets.sync.client.connect(
                self.config.ws_url,
                timeout=30
            )
            # First message must declare which model this session uses.
            initial_request = {
                'setup': {
                    'model': self.config.model,
                }
            }
            self.ws.send(json.dumps(initial_request))
            setup_response = json.loads(self.ws.recv())
            print(f"Setup response: {setup_response}")
        except websockets.exceptions.WebSocketException as e:
            print(f"WebSocket connection failed: {str(e)}")
            self.ws = None
        except Exception as e:
            print(f"Setup failed: {str(e)}")
            self.ws = None

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Send one incoming audio frame (rate, samples) to Gemini.

        Any error tears the socket down; the next frame re-connects.
        """
        try:
            if not self.ws:
                self._initialize_websocket()

            # NOTE(review): the frame's own sample rate is discarded and the
            # chunk is labeled with output_sample_rate; both are 24000 here,
            # but confirm upstream frames actually arrive at 24 kHz.
            _, array = frame
            array = array.squeeze()
            audio_message = self.audio_processor.encode_audio(array, self.output_sample_rate)
            self.ws.send(json.dumps(audio_message))
        except Exception as e:
            print(f"Error in receive: {str(e)}")
            if self.ws:
                self.ws.close()
            self.ws = None

    def _process_server_content(self, content):
        """Yield fixed-size (rate, frame) tuples decoded from a modelTurn dict.

        Accumulates decoded audio in ``self.all_output_data`` and slices off
        ``output_frame_size``-sample frames; any remainder stays buffered for
        the next server message.
        """
        for part in content.get('parts', []):
            data = part.get('inlineData', {}).get('data', '')
            if data:
                audio_array = self.audio_processor.process_audio_response(data)
                if self.all_output_data is None:
                    self.all_output_data = audio_array
                else:
                    self.all_output_data = np.concatenate((self.all_output_data, audio_array))

                while self.all_output_data.shape[-1] >= self.output_frame_size:
                    # Frames are emitted shaped (1, frame_size) — mono channel first.
                    yield (self.output_sample_rate,
                           self.all_output_data[:self.output_frame_size].reshape(1, -1))
                    self.all_output_data = self.all_output_data[self.output_frame_size:]

    def generator(self):
        """Endless generator of output frames (or None when nothing is ready)."""
        while True:
            if not self.ws:
                # NOTE(review): with no socket this yields None in a tight
                # loop — relies on emit() being called at the stream's pace.
                print("WebSocket not connected")
                yield None
                continue

            try:
                message = self.ws.recv(timeout=5)
                msg = json.loads(message)

                if 'serverContent' in msg:
                    content = msg['serverContent'].get('modelTurn', {})
                    yield from self._process_server_content(content)
            except TimeoutError:
                print("Timeout waiting for server response")
                yield None
            except Exception as e:
                print(f"Error in generator: {str(e)}")
                yield None

    def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next output audio frame, or None if none is available.

        Lazily creates the underlying generator on first call; resets state
        if the generator is ever exhausted.
        """
        if not self.ws:
            return None
        if not hasattr(self, '_generator'):
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self.reset()
            return None

    def reset(self) -> None:
        """Drop the output generator and any buffered output audio."""
        if hasattr(self, '_generator'):
            delattr(self, '_generator')
        self.all_output_data = None

    def shutdown(self) -> None:
        """Close the websocket if it is open."""
        if self.ws:
            self.ws.close()

    def check_connection(self):
        """Ensure the websocket is usable, reconnecting if needed.

        Returns True on success, False if (re)connection raised.
        """
        try:
            # NOTE(review): `.closed` may not exist on websockets' sync
            # connection object; an AttributeError here would be caught below
            # and reported as a failed check — verify the attribute name.
            if not self.ws or self.ws.closed:
                self._initialize_websocket()
            return True
        except Exception as e:
            print(f"Connection check failed: {str(e)}")
            return False
161
+
162
class GeminiVoiceChat:
    """Gradio Blocks app that streams microphone audio to Gemini over WebRTC."""

    def __init__(self):
        # Load .env so GOOGLE_API_KEY (and Twilio TURN credentials) resolve.
        load_dotenv()
        self.demo = self._create_interface()

    def _create_interface(self):
        """Build and return the Blocks UI with a send/receive audio stream."""
        with gr.Blocks() as blocks:
            gr.HTML("""
            <div style='text-align: center'>
                <h1>Gemini 2.0 Voice Chat</h1>
                <p>Speak with Gemini using real-time audio streaming</p>
            </div>
            """)

            # TURN credentials let WebRTC traverse NAT in hosted environments.
            audio = WebRTC(
                label="Conversation",
                modality="audio",
                mode="send-receive",
                rtc_configuration=get_twilio_turn_credentials()
            )

            # Same component is both source and sink: full-duplex audio.
            audio.stream(
                GeminiHandler(),
                inputs=[audio],
                outputs=[audio],
                time_limit=90,
                concurrency_limit=10
            )
            return blocks

    def launch(self):
        """Serve publicly on 0.0.0.0 (port from $PORT, default 7860)."""
        launch_options = dict(
            server_name="0.0.0.0",
            server_port=int(os.environ.get("PORT", 7860)),
            share=True,
            ssl_verify=False,
            ssl_keyfile=None,
            ssl_certfile=None,
        )
        self.demo.launch(**launch_options)
201
+
202
def demo():
    """Factory: build a fresh GeminiVoiceChat and return its Blocks app.

    Imported by app.py, which calls this to obtain the provider demo.
    """
    return GeminiVoiceChat().demo
205
+
206
# Entry point for running this module directly (it is also importable).
if __name__ == "__main__":
    GeminiVoiceChat().launch()
requirements.txt CHANGED
@@ -400,3 +400,8 @@ playai-gradio @ git+https://github.com/AK391/playai-gradio.git
400
  lumaai-gradio @ git+https://github.com/AK391/lumaai-gradio.git
401
 
402
  cohere-gradio @ git+https://github.com/AK391/cohere-gradio.git
 
 
 
 
 
 
400
  lumaai-gradio @ git+https://github.com/AK391/lumaai-gradio.git
401
 
402
  cohere-gradio @ git+https://github.com/AK391/cohere-gradio.git
403
+
404
+ gradio_webrtc==0.0.23
405
+ librosa
406
+ python-dotenv
407
+ twilio