dawood's picture
dawood HF Staff
Update app.py
de49ea7 verified
import asyncio
import base64
import os
import time
from io import BytesIO
import gradio as gr
import numpy as np
import websockets
from dotenv import load_dotenv
from fastrtc import (
AsyncAudioVideoStreamHandler,
Stream,
WebRTCError,
get_hf_turn_credentials,
wait_for_item,
)
from google import genai
from PIL import Image
load_dotenv()
def encode_audio(data: np.ndarray) -> dict:
    """Package raw audio samples as a base64 PCM message for the server.

    Args:
        data: Audio samples; the array's bytes are sent exactly as stored.

    Returns:
        A dict with ``mime_type`` set to ``"audio/pcm"`` and ``data`` holding
        the base64 text of the array's bytes.
    """
    raw_pcm = data.tobytes()
    payload = base64.b64encode(raw_pcm).decode("UTF-8")
    return dict(mime_type="audio/pcm", data=payload)
def encode_image(data: np.ndarray) -> dict:
    """Encode an image array as a base64 JPEG message for the server.

    Args:
        data: Image pixels in any layout ``PIL.Image.fromarray`` accepts.

    Returns:
        A dict with ``mime_type`` set to ``"image/jpeg"`` and ``data`` holding
        the base64 text of the JPEG bytes.
    """
    # Modes Pillow's JPEG writer accepts directly; anything else (for example
    # RGBA from a PNG with an alpha channel) makes ``save`` raise OSError.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    pil_image = Image.fromarray(data)
    if pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert("RGB")

    with BytesIO() as output_bytes:
        pil_image.save(output_bytes, "JPEG")
        bytes_data = output_bytes.getvalue()

    base64_str = base64.b64encode(bytes_data).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": base64_str}
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relay a stream's audio and video to a Gemini live session.

    Incoming audio chunks and (throttled) video frames are sent to the
    session; audio returned by the model is queued and handed back through
    ``emit``. Incoming video frames are echoed back through ``video_emit``.
    """

    def __init__(
        self,
    ) -> None:
        """Configure mono audio (16 kHz in, 24 kHz out) and empty state."""
        super().__init__(
            "mono",
            output_sample_rate=24000,
            input_sample_rate=16000,
        )
        # Model audio waiting to be returned by ``emit``.
        self.audio_queue = asyncio.Queue()
        # Incoming frames waiting to be returned by ``video_emit``.
        self.video_queue = asyncio.Queue()
        # Live Gemini session; None whenever no connection is open.
        self.session = None
        # ``time.time()`` of the last frame forwarded to the model.
        self.last_frame_time = 0
        # Set by ``shutdown`` to stop the receive loop in ``start_up``.
        self.quit = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with its own queues and session state."""
        return GeminiHandler()

    async def start_up(self):
        """Open the Gemini live session and queue its audio until it ends.

        The Gemini API key and the HF token are read from ``latest_args[3]``
        and ``latest_args[4]`` once the arguments are available.

        Raises:
            WebRTCError: If the HF token is missing or empty.
        """
        await self.wait_for_args()
        api_key = self.latest_args[3]
        hf_token = self.latest_args[4]
        if not hf_token:
            raise WebRTCError("HF Token is required")
        # NOTE(review): os.environ is process-wide, so with several handlers
        # running at once the last token written wins — confirm intended.
        os.environ["HF_TOKEN"] = hf_token

        client = genai.Client(
            api_key=api_key, http_options={"api_version": "v1alpha"}
        )
        config = {
            "response_modalities": ["AUDIO"],
            "system_instruction": (
                "You are an art critic that will critique the artwork passed in as an image to the user. "
                "Critique the artwork in a funny and lighthearted way. "
                "Be concise and to the point. "
                "Be friendly and engaging. "
                "Be helpful and informative. "
                "Be funny and lighthearted. "
                "Be concise and to the point. "
                "Be friendly and engaging."
            ),
        }
        try:
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp",
                config=config,  # type: ignore
            ) as session:
                self.session = session
                while not self.quit.is_set():
                    try:
                        # Each pass of the outer loop iterates one
                        # ``receive()`` stream from the session.
                        async for response in session.receive():
                            if data := response.data:
                                # Treat the payload as 16-bit samples shaped
                                # (1, n) before queueing it for ``emit``.
                                audio = np.frombuffer(
                                    data, dtype=np.int16
                                ).reshape(1, -1)
                                self.audio_queue.put_nowait(audio)
                    except websockets.exceptions.ConnectionClosedOK:
                        print("connection closed")
                        break
        finally:
            # The connection is closed once the context manager exits; drop
            # the reference so receive/video_receive stop sending on it.
            self.session = None

    async def video_receive(self, frame: np.ndarray):
        """Queue an incoming frame and forward at most one per second to Gemini.

        When a frame is forwarded, the image held in ``latest_args[2]`` is
        sent right after it if one is present.
        """
        self.video_queue.put_nowait(frame)

        session = self.session
        if not session:
            return

        # Throttle: forward a frame only when more than a second has passed.
        now = time.time()
        if now - self.last_frame_time <= 1:
            return
        self.last_frame_time = now

        await session.send(input=encode_image(frame))
        uploaded = self.latest_args[2]
        if uploaded is not None:
            await session.send(input=encode_image(uploaded))

    async def video_emit(self):
        """Return the next queued frame, or a black 100x100 placeholder."""
        # Presumably waits up to 0.01 s and yields None on timeout — confirm
        # against fastrtc's ``wait_for_item``.
        frame = await wait_for_item(self.video_queue, 0.01)
        if frame is None:
            return np.zeros((100, 100, 3), dtype=np.uint8)
        return frame

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Send one incoming audio chunk to Gemini if a session is open.

        Args:
            frame: ``(sample_rate, samples)``; only the samples are used.
        """
        if not self.session:
            # Nothing to send to yet; skip the encoding work entirely.
            return
        _, array = frame
        await self.session.send(input=encode_audio(array.squeeze()))

    async def emit(self):
        """Return ``(output_sample_rate, audio)`` for queued audio, else None."""
        array = await wait_for_item(self.audio_queue, 0.01)
        if array is None:
            return None
        return (self.output_sample_rate, array)

    async def shutdown(self) -> None:
        """Stop the receive loop and close the Gemini session, if one is open."""
        session = self.session
        if not session:
            return
        self.quit.set()
        await session.close()
        self.quit.clear()
# Handler instance passed to the Stream below.
gemini_handler = GeminiHandler()

# Audio + video, send-and-receive stream wired to the Gemini handler.
stream = Stream(
    handler=gemini_handler,
    modality="audio-video",
    mode="send-receive",
    server_rtc_configuration=get_hf_turn_credentials(ttl=600*10000),
    rtc_configuration=get_hf_turn_credentials(),
    # Order matters: GeminiHandler reads the artwork image, the Gemini API key
    # and the HF token from latest_args[2], latest_args[3] and latest_args[4].
    additional_inputs=[
        gr.Markdown(
            # Each sentence group ends with a blank line so the adjacent
            # literals do not run together when concatenated.
            "## 🎨 Art Critic\n\n"
            "Provide an image of your artwork or hold it up to the webcam and Gemini will critique it for you.\n\n"
            "To get a Gemini API key, please visit the [Gemini API Key](https://aistudio.google.com/apikey) page.\n\n"
            "To get an HF Token, please visit the [HF Token](https://huggingface.co/settings/tokens) page. The token requires Read access."
        ),
        gr.Image(label="Artwork", value="mona_lisa.jpg", type="numpy", sources=["upload", "clipboard"]),
        gr.Textbox(label="Gemini API Key", type="password"),
        gr.Textbox(label="HF Token", type="password"),
    ],
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(255, 255, 255)",
        "icon_button_color": "rgb(255, 255, 255)",
        "title": "Gemini Audio Video Chat",
    },
    time_limit=90,
    concurrency_limit=5,
)

# Launch the UI only when run as a script, not on import.
if __name__ == "__main__":
    stream.ui.launch()