Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import asyncio | |
| import base64 | |
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import openai | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI | |
| from fastapi.responses import HTMLResponse, StreamingResponse | |
| from fastrtc import ( | |
| AdditionalOutputs, | |
| AsyncStreamHandler, | |
| Stream, | |
| get_twilio_turn_credentials, | |
| wait_for_item, | |
| ) | |
| from gradio.utils import get_space | |
# Load variables from a local .env file into the process environment before
# any client is constructed.
# NOTE(review): presumably this is where the OpenAI credentials come from --
# confirm against the deployment setup.
load_dotenv()

# Directory that contains this script; index.html is looked up next to it.
cur_dir = Path(__file__).parent

# Sample rate passed to the stream handler for both its input and output audio.
SAMPLE_RATE = 24_000
class OpenAIHandler(AsyncStreamHandler):
    """Relay audio between a FastRTC stream and an OpenAI realtime session.

    Incoming frames are base64-encoded and appended to the session's input
    audio buffer. Audio deltas and finished transcripts coming back from the
    session are placed on ``output_queue`` and handed out one at a time by
    ``emit``.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            input_sample_rate=SAMPLE_RATE,
        )
        # Live realtime connection. None until start_up() has configured the
        # session, and None again once the session has ended.
        self.connection = None
        # Holds (sample_rate, int16 array) audio tuples and AdditionalOutputs
        # transcript messages waiting to be returned by emit().
        self.output_queue = asyncio.Queue()

    def copy(self):
        """Return a new, not-yet-connected handler."""
        return OpenAIHandler()

    async def start_up(
        self,
    ):
        """Connect to the realtime API and process its events until the stream ends.

        The session is configured for server-side voice activity detection and
        English Whisper transcription of the input audio. Each received event
        is translated into zero or one item on ``output_queue``.
        """
        self.client = openai.AsyncOpenAI()
        async with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
            await conn.session.update(
                session={
                    "turn_detection": {"type": "server_vad"},
                    "input_audio_transcription": {
                        "model": "whisper-1",
                        "language": "en",
                    },
                }
            )
            # Publish the connection only after the session is configured so
            # receive() never sends audio to a half-initialised session.
            self.connection = conn
            try:
                async for event in conn:
                    # An event carries exactly one type, so the branches are
                    # mutually exclusive.
                    if event.type == "input_audio_buffer.speech_started":
                        # Handle interruptions: discard queued output.
                        # clear_queue() is not defined in this class; it is
                        # expected to come from the base handler.
                        self.clear_queue()
                    elif (
                        event.type
                        == "conversation.item.input_audio_transcription.completed"
                    ):
                        await self.output_queue.put(
                            AdditionalOutputs(
                                {"role": "user", "content": event.transcript}
                            )
                        )
                    elif event.type == "response.audio_transcript.done":
                        await self.output_queue.put(
                            AdditionalOutputs(
                                {"role": "assistant", "content": event.transcript}
                            )
                        )
                    elif event.type == "response.audio.delta":
                        # The delta is base64 text; decode it to int16 samples
                        # shaped (1, n_samples).
                        samples = np.frombuffer(
                            base64.b64decode(event.delta), dtype=np.int16
                        ).reshape(1, -1)
                        await self.output_queue.put(
                            (self.output_sample_rate, samples)
                        )
            finally:
                # The event stream is over (closed or failed). Drop the stale
                # reference so receive() stops writing to a dead connection.
                # Leave it alone if shutdown() already cleared it.
                if self.connection is conn:
                    self.connection = None

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one incoming audio frame to the realtime session.

        Frames that arrive while no connection is available are dropped.
        """
        if not self.connection:
            return
        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
        await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the result of ``wait_for_item`` on the output queue."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the realtime connection, if any, and forget it."""
        if self.connection:
            await self.connection.close()
            self.connection = None
def update_chatbot(chatbot: list[dict], response: dict):
    """Add ``response`` to the end of the chat history and return that history.

    The list passed in is modified in place; the returned object is the same
    list, not a copy.
    """
    chatbot += [response]
    return chatbot
# Chat transcript component, used as both an additional input and an
# additional output of the stream below.
chatbot = gr.Chatbot(type="messages")
latest_message = gr.Textbox(type="text", visible=False)

# Truthy when get_space() reports a hosted Space. The TURN credentials and the
# usage limits below are applied only in that case.
_hosted_space = get_space()

stream = Stream(
    OpenAIHandler(),
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if _hosted_space else None,
    concurrency_limit=5 if _hosted_space else None,
    time_limit=90 if _hosted_space else None,
)

# FastAPI application that the stream is mounted on.
app = FastAPI()
stream.mount(app)
# Without a route decorator this handler is never reachable (and is shadowed by
# the next function that is also named `_`), so register it explicitly.
# NOTE(review): "/" is the assumed path for the index page -- confirm against
# the URLs index.html expects.
@app.get("/")
async def _():
    """Serve index.html with the RTC configuration injected.

    The ``__RTC_CONFIGURATION__`` placeholder in the page is replaced by the
    JSON-encoded TURN credentials when running on a Space, or by ``null``
    otherwise.
    """
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    # Read with an explicit encoding so the result does not depend on the
    # platform's default locale.
    html_content = (cur_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
# Registered explicitly; an undecorated function would never be served.
# NOTE(review): "/outputs" is the assumed path -- confirm it matches the URL
# the front-end subscribes to.
@app.get("/outputs")
def _(webrtc_id: str):
    """Stream the additional outputs of one WebRTC session as server-sent events.

    Args:
        webrtc_id: Identifier of the session whose outputs are streamed.

    Returns:
        A ``text/event-stream`` response that emits one ``output`` event per
        item, with the item's first argument JSON-encoded as the event data.
    """

    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            # args[0] is the first positional value the output was built with;
            # the handler in this file builds them from role/content dicts.
            payload = json.dumps(output.args[0])
            yield f"event: output\ndata: {payload}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
if __name__ == "__main__":
    import os

    # The MODE environment variable selects the front-end: the built-in UI,
    # the phone interface, or (by default) the FastAPI app under uvicorn.
    mode = os.getenv("MODE")
    if mode == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)