# gemini-webrtc / app.py
# Author: mgokg — commit 6904a70 ("Update app.py", verified)
import asyncio
import base64
import os
import time
from io import BytesIO
import gradio as gr
import numpy as np
import websockets
from dotenv import load_dotenv
from fastrtc import (
AsyncAudioVideoStreamHandler,
Stream,
WebRTC,
get_cloudflare_turn_credentials_async,
wait_for_item,
)
from google import genai
from google.genai import types # Import the types module
from gradio.utils import get_space
from PIL import Image
load_dotenv()
system_message = "you are a helpful assistant."
#system_message = "Du bist ein echzeitübersetzer. übersetze deutsch auf italienisch und italienisch auf deutsch. erkläre nichts, kommentiere nichts, füge nichts hinzu, nur übersetzen."
def encode_audio(data: np.ndarray) -> dict:
    """Package raw PCM samples as a base64 message for the live API."""
    payload = base64.b64encode(data.tobytes())
    return {
        "mime_type": "audio/pcm",
        "data": payload.decode("UTF-8"),
    }
def encode_image(data: np.ndarray) -> dict:
    """Serialize an RGB frame as a base64-encoded JPEG message for the live API."""
    buffer = BytesIO()
    Image.fromarray(data).save(buffer, "JPEG")
    jpeg_bytes = buffer.getvalue()
    encoded = base64.b64encode(jpeg_bytes).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": encoded}
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relay a fastrtc WebRTC audio/video stream to a Gemini Live API session.

    Client microphone audio is forwarded to Gemini as it arrives; Gemini's
    spoken replies are buffered in ``audio_queue`` and emitted back to the
    client. Video frames are echoed locally and uploaded to Gemini at most
    once per second.
    """

    def __init__(
        self,
    ) -> None:
        # Mono audio: client capture at 16 kHz, Gemini playback at 24 kHz.
        super().__init__(
            "mono",
            output_sample_rate=24000,
            input_sample_rate=16000,
        )
        self.audio_queue = asyncio.Queue()  # Gemini audio chunks awaiting emit()
        self.video_queue = asyncio.Queue()  # client frames awaiting video_emit()
        self.session = None  # live-API session; set inside start_up()
        self.last_frame_time = 0  # wall-clock time of the last frame upload
        self.quit = asyncio.Event()  # set by shutdown() to stop the receive loop

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler instance (one is cloned per connection)."""
        return GeminiHandler()

    async def start_up(self):
        """Open the Gemini Live session and pump its responses until quit.

        Configures German ("de-DE") audio replies with the "Kore" voice,
        Google Search as a tool, and ``system_message`` as the system
        instruction, then forwards every audio chunk Gemini returns into
        ``audio_queue``.
        """
        client = genai.Client(
            api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"}
        )
        # Define the tools and system instruction.
        tools = [
            types.Tool(google_search=types.GoogleSearch()),
        ]
        system_instruction = types.Content(
            parts=[types.Part.from_text(text=f"{system_message}")],
            role="user"
        )
        # Config bundles response modality, voice, tools and system instruction.
        config = types.LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                language_code="de-DE",
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Kore")
                )
            ),
            tools=tools,
            system_instruction=system_instruction,
        )
        async with client.aio.live.connect(
            #model="gemini-2.0-flash-exp",
            model = "models/gemini-2.5-flash-preview-native-audio-dialog",
            config=config, # type: ignore
        ) as session:
            self.session = session
            while not self.quit.is_set():
                turn = self.session.receive()
                try:
                    async for response in turn:
                        # Check data exists before treating the payload as audio.
                        if data := response.data:
                            # NOTE(review): treated as 16-bit PCM, reshaped to
                            # (channels, samples) — assumes mono output; confirm.
                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
                            self.audio_queue.put_nowait(audio)  # only enqueue real audio
                        # Other response parts (text, tool calls, ...) are ignored
                        # here; the guard above just avoids crashing when data is None.
                except websockets.exceptions.ConnectionClosedOK:
                    print("connection closed")
                    break
                except Exception as e:
                    # Catch other potential errors during response processing.
                    print(f"Error processing response: {e}")
                    # Break rather than continue to avoid spinning forever on a
                    # persistent error.
                    break

    async def video_receive(self, frame: np.ndarray):
        """Echo the incoming frame locally and throttle uploads to Gemini to ~1/s."""
        self.video_queue.put_nowait(frame)
        if self.session:
            # send image every 1 second
            print(time.time() - self.last_frame_time)
            if time.time() - self.last_frame_time > 1:
                self.last_frame_time = time.time()
                await self.session.send(input=encode_image(frame))
                # latest_args[1] presumably holds the optional uploaded image
                # from the extra gr.Image input (populated by fastrtc) — verify.
                if self.latest_args[1] is not None:
                    await self.session.send(input=encode_image(self.latest_args[1]))

    async def video_emit(self):
        """Return the next queued frame, or a black placeholder if none is ready."""
        frame = await wait_for_item(self.video_queue, 0.01)
        if frame is not None:
            return frame
        else:
            return np.zeros((100, 100, 3), dtype=np.uint8)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one chunk of client microphone audio to Gemini."""
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        # Only send while a session exists; the connection may already be gone.
        if self.session:
            try:
                await self.session.send(input=audio_message)
            except websockets.exceptions.ConnectionClosedOK:
                print("Attempted to send on a closed connection.")
            except Exception as e:
                print(f"Error sending audio message: {e}")
        else:
            print("Session not active, cannot send audio message.")

    async def emit(self):
        """Emit the next Gemini audio chunk as (sample_rate, array), or None."""
        array = await wait_for_item(self.audio_queue, 0.01)
        if array is not None:
            return (self.output_sample_rate, array)
        return array

    async def shutdown(self) -> None:
        """Signal the receive loop to stop and close the live session."""
        if self.session:
            self.quit.set()
            await self.session.close()
            self.quit.clear()  # reset so the handler could be started again
# Standalone fastrtc Stream wiring GeminiHandler to the WebRTC transport; its
# default UI is replaced below via `stream.ui = demo`.
stream = Stream(
    handler=GeminiHandler(),
    modality="audio-video",
    mode="send-receive",
    # Cloudflare TURN credentials let peers connect across NATs/firewalls.
    rtc_configuration=get_cloudflare_turn_credentials_async,
    # Cap sessions at 30 minutes when running on Hugging Face Spaces.
    time_limit=1800 if get_space() else None,
    additional_inputs=[
        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
    ],
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(255, 255, 255)",
        "icon_button_color": "rgb(255, 255, 255)",
        "title": "Gemini Audio Video Chat",
    },
)
css = """
#video-source {max-width: 500px !important; max-height: 500px !important; background-color: #0f0f11 }
#video-source video {
background-color: black !important;
}
"""
# Custom Gradio UI: a single WebRTC widget streaming audio through GeminiHandler.
with gr.Blocks(css=css) as demo:
    # Placeholder header markup (currently empty).
    gr.HTML(
        """
<div>
<center>
</center>
</div>
"""
    )
    with gr.Row() as row:
        with gr.Column():
            webrtc = WebRTC(
                label="Voice Chat",
                modality="audio",
                mode="send-receive",
                elem_id="video-source",
                rtc_configuration=get_cloudflare_turn_credentials_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
        # Disabled second column: optional still-image input.
        #with gr.Column():
            #image_input = gr.Image(
                #label="Image", type="numpy", sources=["upload", "clipboard"]
            #)
    # Route browser audio through a fresh GeminiHandler and back to the widget.
    webrtc.stream(
        GeminiHandler(),
        inputs=[webrtc],
        outputs=[webrtc],
        time_limit=1800 if get_space() else None,
        concurrency_limit=2 if get_space() else None,
    )
# Attach the custom Blocks UI to the fastrtc stream defined above.
stream.ui = demo
if __name__ == "__main__":
if (mode := os.getenv("MODE")) == "UI":
stream.ui.launch(server_port=7860)
elif mode == "PHONE":
raise ValueError("Phone mode not supported for this demo")
else:
stream.ui.launch(server_port=7860)