freddyaboulton HF staff committed on
Commit
3dbb28b
·
verified ·
1 Parent(s): 3e16f8a

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +150 -0
  2. index.html +418 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import openai
9
+ from dotenv import load_dotenv
10
+ from fastapi import FastAPI
11
+ from fastapi.responses import HTMLResponse, StreamingResponse
12
+ from fastrtc import (
13
+ AdditionalOutputs,
14
+ ReplyOnPause,
15
+ Stream,
16
+ WebRTCError,
17
+ get_twilio_turn_credentials,
18
+ stt,
19
+ )
20
+ from gradio.utils import get_space
21
+ from pydantic import BaseModel
22
+
23
+ load_dotenv()
24
+
25
+ curr_dir = Path(__file__).parent
26
+
27
+
28
+ client = openai.OpenAI(
29
+ api_key=os.environ.get("SAMBANOVA_API_KEY"),
30
+ base_url="https://api.sambanova.ai/v1",
31
+ )
32
+
33
+
34
+ def response(
35
+ audio: tuple[int, np.ndarray],
36
+ gradio_chatbot: list[dict] | None = None,
37
+ conversation_state: list[dict] | None = None,
38
+ ):
39
+ gradio_chatbot = gradio_chatbot or []
40
+ conversation_state = conversation_state or []
41
+
42
+ text = stt(audio)
43
+ sample_rate, array = audio
44
+ gradio_chatbot.append(
45
+ {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
46
+ )
47
+ yield AdditionalOutputs(gradio_chatbot, conversation_state)
48
+
49
+ conversation_state.append({"role": "user", "content": text})
50
+
51
+ try:
52
+ request = client.chat.completions.create(
53
+ model="Meta-Llama-3.2-3B-Instruct",
54
+ messages=conversation_state, # type: ignore
55
+ temperature=0.1,
56
+ top_p=0.1,
57
+ )
58
+ response = {"role": "assistant", "content": request.choices[0].message.content}
59
+
60
+ except Exception:
61
+ import traceback
62
+
63
+ traceback.print_exc()
64
+ raise WebRTCError(traceback.format_exc())
65
+
66
+ conversation_state.append(response)
67
+ gradio_chatbot.append(response)
68
+
69
+ yield AdditionalOutputs(gradio_chatbot, conversation_state)
70
+
71
+
72
+ chatbot = gr.Chatbot(type="messages", value=[])
73
+ state = gr.State(value=[])
74
+ stream = Stream(
75
+ ReplyOnPause(
76
+ response, # type: ignore
77
+ input_sample_rate=16000,
78
+ ),
79
+ mode="send",
80
+ modality="audio",
81
+ additional_inputs=[chatbot, state],
82
+ additional_outputs=[chatbot, state],
83
+ additional_outputs_handler=lambda *a: (a[2], a[3]),
84
+ concurrency_limit=20 if get_space() else None,
85
+ )
86
+
87
+ app = FastAPI()
88
+ stream.mount(app)
89
+
90
+
91
+ class Message(BaseModel):
92
+ role: str
93
+ content: str
94
+
95
+
96
+ class InputData(BaseModel):
97
+ webrtc_id: str
98
+ chatbot: list[Message]
99
+ state: list[Message]
100
+
101
+
102
+ @app.get("/")
103
+ async def _():
104
+ rtc_config = get_twilio_turn_credentials() if get_space() else None
105
+ html_content = (curr_dir / "index.html").read_text()
106
+ html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
107
+ return HTMLResponse(content=html_content)
108
+
109
+
110
+ @app.post("/input_hook")
111
+ async def _(data: InputData):
112
+ body = data.model_dump()
113
+ stream.set_input(data.webrtc_id, body["chatbot"], body["state"])
114
+
115
+
116
+ def audio_to_base64(file_path):
117
+ audio_format = "wav"
118
+ with open(file_path, "rb") as audio_file:
119
+ encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
120
+ return f"data:audio/{audio_format};base64,{encoded_audio}"
121
+
122
+
123
+ @app.get("/outputs")
124
+ async def _(webrtc_id: str):
125
+ async def output_stream():
126
+ async for output in stream.output_stream(webrtc_id):
127
+ chatbot = output.args[0]
128
+ state = output.args[1]
129
+ data = {
130
+ "message": state[-1],
131
+ "audio": audio_to_base64(chatbot[-1]["content"].value["path"])
132
+ if chatbot[-1]["role"] == "user"
133
+ else None,
134
+ }
135
+ yield f"event: output\ndata: {json.dumps(data)}\n\n"
136
+
137
+ return StreamingResponse(output_stream(), media_type="text/event-stream")
138
+
139
+
140
+ if __name__ == "__main__":
141
+ import os
142
+
143
+ if (mode := os.getenv("MODE")) == "UI":
144
+ stream.ui.launch(server_port=7860)
145
+ elif mode == "PHONE":
146
+ raise ValueError("Phone mode not supported")
147
+ else:
148
+ import uvicorn
149
+
150
+ uvicorn.run(app, host="0.0.0.0", port=7860)
index.html ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Talk to Sambanova</title>
8
+ <style>
9
+ body {
10
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
11
+ background-color: #f8f9fa;
12
+ color: #1a1a1a;
13
+ margin: 0;
14
+ padding: 20px;
15
+ height: 100vh;
16
+ box-sizing: border-box;
17
+ }
18
+
19
+ .container {
20
+ max-width: 800px;
21
+ margin: 0 auto;
22
+ height: 80%;
23
+ }
24
+
25
+ .logo {
26
+ text-align: center;
27
+ margin-bottom: 40px;
28
+ }
29
+
30
+ .chat-container {
31
+ background: white;
32
+ border-radius: 8px;
33
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
34
+ padding: 20px;
35
+ height: 90%;
36
+ box-sizing: border-box;
37
+ display: flex;
38
+ flex-direction: column;
39
+ }
40
+
41
+ .chat-messages {
42
+ flex-grow: 1;
43
+ overflow-y: auto;
44
+ margin-bottom: 20px;
45
+ padding: 10px;
46
+ }
47
+
48
+ .message {
49
+ margin-bottom: 20px;
50
+ padding: 12px;
51
+ border-radius: 8px;
52
+ font-size: 14px;
53
+ line-height: 1.5;
54
+ }
55
+
56
+ .message.user {
57
+ background-color: #e9ecef;
58
+ margin-left: 20%;
59
+ }
60
+
61
+ .message.assistant {
62
+ background-color: #f1f3f5;
63
+ margin-right: 20%;
64
+ }
65
+
66
+ .controls {
67
+ text-align: center;
68
+ margin-top: 20px;
69
+ }
70
+
71
+ button {
72
+ background-color: #0066cc;
73
+ color: white;
74
+ border: none;
75
+ padding: 12px 24px;
76
+ font-family: inherit;
77
+ font-size: 14px;
78
+ cursor: pointer;
79
+ transition: all 0.3s;
80
+ border-radius: 4px;
81
+ font-weight: 500;
82
+ }
83
+
84
+ button:hover {
85
+ background-color: #0052a3;
86
+ }
87
+
88
+ #audio-output {
89
+ display: none;
90
+ }
91
+
92
+ .icon-with-spinner {
93
+ display: flex;
94
+ align-items: center;
95
+ justify-content: center;
96
+ gap: 12px;
97
+ min-width: 180px;
98
+ }
99
+
100
+ .spinner {
101
+ width: 20px;
102
+ height: 20px;
103
+ border: 2px solid #ffffff;
104
+ border-top-color: transparent;
105
+ border-radius: 50%;
106
+ animation: spin 1s linear infinite;
107
+ flex-shrink: 0;
108
+ }
109
+
110
+ @keyframes spin {
111
+ to {
112
+ transform: rotate(360deg);
113
+ }
114
+ }
115
+
116
+ .pulse-container {
117
+ display: flex;
118
+ align-items: center;
119
+ justify-content: center;
120
+ gap: 12px;
121
+ min-width: 180px;
122
+ }
123
+
124
+ .pulse-circle {
125
+ width: 20px;
126
+ height: 20px;
127
+ border-radius: 50%;
128
+ background-color: #ffffff;
129
+ opacity: 0.2;
130
+ flex-shrink: 0;
131
+ transform: translateX(-0%) scale(var(--audio-level, 1));
132
+ transition: transform 0.1s ease;
133
+ }
134
+
135
+ /* Add styles for typing indicator */
136
+ .typing-indicator {
137
+ padding: 8px;
138
+ background-color: #f1f3f5;
139
+ border-radius: 8px;
140
+ margin-bottom: 10px;
141
+ display: none;
142
+ }
143
+
144
+ .dots {
145
+ display: inline-flex;
146
+ gap: 4px;
147
+ }
148
+
149
+ .dot {
150
+ width: 8px;
151
+ height: 8px;
152
+ background-color: #0066cc;
153
+ border-radius: 50%;
154
+ animation: pulse 1.5s infinite;
155
+ opacity: 0.5;
156
+ }
157
+
158
+ .dot:nth-child(2) {
159
+ animation-delay: 0.5s;
160
+ }
161
+
162
+ .dot:nth-child(3) {
163
+ animation-delay: 1s;
164
+ }
165
+
166
+ @keyframes pulse {
167
+
168
+ 0%,
169
+ 100% {
170
+ opacity: 0.5;
171
+ transform: scale(1);
172
+ }
173
+
174
+ 50% {
175
+ opacity: 1;
176
+ transform: scale(1.2);
177
+ }
178
+ }
179
+ </style>
180
+ </head>
181
+
182
+ <body>
183
+ <div class="container">
184
+ <div class="logo">
185
+ <h1>Talk to Sambanova 🗣️</h1>
186
+ <h2 style="font-size: 1.2em; color: #666; margin-top: 10px;">Speak to Llama 3.2 powered by Sambanova API
187
+ </h2>
188
+ </div>
189
+ <div class="chat-container">
190
+ <div class="chat-messages" id="chat-messages"></div>
191
+ <div class="typing-indicator" id="typing-indicator">
192
+ <div class="dots">
193
+ <div class="dot"></div>
194
+ <div class="dot"></div>
195
+ <div class="dot"></div>
196
+ </div>
197
+ </div>
198
+ </div>
199
+ <div class="controls">
200
+ <button id="start-button">Start Conversation</button>
201
+ </div>
202
+ </div>
203
+ <audio id="audio-output"></audio>
204
+
205
+ <script>
206
+ let peerConnection;
207
+ let webrtc_id;
208
+ const startButton = document.getElementById('start-button');
209
+ const chatMessages = document.getElementById('chat-messages');
210
+
211
+ let audioLevel = 0;
212
+ let animationFrame;
213
+ let audioContext, analyser, audioSource;
214
+ let messages = [];
215
+ let eventSource;
216
+
217
+ function updateButtonState() {
218
+ const button = document.getElementById('start-button');
219
+ if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
220
+ button.innerHTML = `
221
+ <div class="icon-with-spinner">
222
+ <div class="spinner"></div>
223
+ <span>Connecting...</span>
224
+ </div>
225
+ `;
226
+ } else if (peerConnection && peerConnection.connectionState === 'connected') {
227
+ button.innerHTML = `
228
+ <div class="pulse-container">
229
+ <div class="pulse-circle"></div>
230
+ <span>Stop Conversation</span>
231
+ </div>
232
+ `;
233
+ } else {
234
+ button.innerHTML = 'Start Conversation';
235
+ }
236
+ }
237
+
238
+ function setupAudioVisualization(stream) {
239
+ audioContext = new (window.AudioContext || window.webkitAudioContext)();
240
+ analyser = audioContext.createAnalyser();
241
+ audioSource = audioContext.createMediaStreamSource(stream);
242
+ audioSource.connect(analyser);
243
+ analyser.fftSize = 64;
244
+ const dataArray = new Uint8Array(analyser.frequencyBinCount);
245
+
246
+ function updateAudioLevel() {
247
+ analyser.getByteFrequencyData(dataArray);
248
+ const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
249
+ audioLevel = average / 255;
250
+
251
+ const pulseCircle = document.querySelector('.pulse-circle');
252
+ if (pulseCircle) {
253
+ pulseCircle.style.setProperty('--audio-level', 1 + audioLevel);
254
+ }
255
+
256
+ animationFrame = requestAnimationFrame(updateAudioLevel);
257
+ }
258
+ updateAudioLevel();
259
+ }
260
+
261
+ function handleMessage(event) {
262
+ const eventJson = JSON.parse(event.data);
263
+ const typingIndicator = document.getElementById('typing-indicator');
264
+
265
+ if (eventJson.type === "send_input") {
266
+ fetch('/input_hook', {
267
+ method: 'POST',
268
+ headers: {
269
+ 'Content-Type': 'application/json',
270
+ },
271
+ body: JSON.stringify({
272
+ webrtc_id: webrtc_id,
273
+ chatbot: messages,
274
+ state: messages
275
+ })
276
+ });
277
+ } else if (eventJson.type === "log") {
278
+ if (eventJson.data === "pause_detected") {
279
+ typingIndicator.style.display = 'block';
280
+ chatMessages.scrollTop = chatMessages.scrollHeight;
281
+ } else if (eventJson.data === "response_starting") {
282
+ typingIndicator.style.display = 'none';
283
+ }
284
+ }
285
+ }
286
+
287
+ async function setupWebRTC() {
288
+ const config = __RTC_CONFIGURATION__;
289
+ peerConnection = new RTCPeerConnection(config);
290
+
291
+ try {
292
+ const stream = await navigator.mediaDevices.getUserMedia({
293
+ audio: true
294
+ });
295
+
296
+ setupAudioVisualization(stream);
297
+
298
+ stream.getTracks().forEach(track => {
299
+ peerConnection.addTrack(track, stream);
300
+ });
301
+
302
+ const dataChannel = peerConnection.createDataChannel('text');
303
+ dataChannel.onmessage = handleMessage;
304
+
305
+ const offer = await peerConnection.createOffer();
306
+ await peerConnection.setLocalDescription(offer);
307
+
308
+ await new Promise((resolve) => {
309
+ if (peerConnection.iceGatheringState === "complete") {
310
+ resolve();
311
+ } else {
312
+ const checkState = () => {
313
+ if (peerConnection.iceGatheringState === "complete") {
314
+ peerConnection.removeEventListener("icegatheringstatechange", checkState);
315
+ resolve();
316
+ }
317
+ };
318
+ peerConnection.addEventListener("icegatheringstatechange", checkState);
319
+ }
320
+ });
321
+
322
+ peerConnection.addEventListener('connectionstatechange', () => {
323
+ console.log('connectionstatechange', peerConnection.connectionState);
324
+ updateButtonState();
325
+ });
326
+
327
+ webrtc_id = Math.random().toString(36).substring(7);
328
+
329
+ const response = await fetch('/webrtc/offer', {
330
+ method: 'POST',
331
+ headers: { 'Content-Type': 'application/json' },
332
+ body: JSON.stringify({
333
+ sdp: peerConnection.localDescription.sdp,
334
+ type: peerConnection.localDescription.type,
335
+ webrtc_id: webrtc_id
336
+ })
337
+ });
338
+
339
+ const serverResponse = await response.json();
340
+ await peerConnection.setRemoteDescription(serverResponse);
341
+
342
+ eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
343
+ eventSource.addEventListener("output", (event) => {
344
+ const eventJson = JSON.parse(event.data);
345
+ console.log(eventJson);
346
+ messages.push(eventJson.message);
347
+ addMessage(eventJson.message.role, eventJson.audio ?? eventJson.message.content);
348
+ });
349
+ } catch (err) {
350
+ console.error('Error setting up WebRTC:', err);
351
+ }
352
+ }
353
+
354
+ function addMessage(role, content) {
355
+ const messageDiv = document.createElement('div');
356
+ messageDiv.classList.add('message', role);
357
+
358
+ if (role === 'user') {
359
+ // Create audio element for user messages
360
+ const audio = document.createElement('audio');
361
+ audio.controls = true;
362
+ audio.src = content;
363
+ messageDiv.appendChild(audio);
364
+ } else {
365
+ // Text content for assistant messages
366
+ messageDiv.textContent = content;
367
+ }
368
+
369
+ chatMessages.appendChild(messageDiv);
370
+ chatMessages.scrollTop = chatMessages.scrollHeight;
371
+ }
372
+
373
+ function stop() {
374
+ if (eventSource) {
375
+ eventSource.close();
376
+ eventSource = null;
377
+ }
378
+
379
+ if (animationFrame) {
380
+ cancelAnimationFrame(animationFrame);
381
+ }
382
+ if (audioContext) {
383
+ audioContext.close();
384
+ audioContext = null;
385
+ analyser = null;
386
+ audioSource = null;
387
+ }
388
+ if (peerConnection) {
389
+ if (peerConnection.getTransceivers) {
390
+ peerConnection.getTransceivers().forEach(transceiver => {
391
+ if (transceiver.stop) {
392
+ transceiver.stop();
393
+ }
394
+ });
395
+ }
396
+
397
+ if (peerConnection.getSenders) {
398
+ peerConnection.getSenders().forEach(sender => {
399
+ if (sender.track && sender.track.stop) sender.track.stop();
400
+ });
401
+ }
402
+ peerConnection.close();
403
+ }
404
+ updateButtonState();
405
+ audioLevel = 0;
406
+ }
407
+
408
+ startButton.addEventListener('click', () => {
409
+ if (!peerConnection || peerConnection.connectionState !== 'connected') {
410
+ setupWebRTC();
411
+ } else {
412
+ stop();
413
+ }
414
+ });
415
+ </script>
416
+ </body>
417
+
418
+ </html>
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastrtc[vad]
2
+ python-dotenv
3
+ openai
4
+ twilio