Spaces:
Running
Running
import asyncio | |
import base64 | |
import os | |
import time | |
from io import BytesIO | |
import gradio as gr | |
import numpy as np | |
import websockets | |
from dotenv import load_dotenv | |
from fastrtc import ( | |
AsyncAudioVideoStreamHandler, | |
Stream, | |
WebRTCError, | |
get_hf_turn_credentials, | |
wait_for_item, | |
) | |
from google import genai | |
from PIL import Image | |
# Populate the process environment from a local .env file, if one exists
# (python-dotenv).
load_dotenv()
def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw PCM samples in the base64 payload expected by the live session.

    Args:
        data: Sample array. Its raw buffer (native byte order) is sent as-is.
            NOTE(review): the payload is labelled "audio/pcm", so this
            presumably expects int16 samples; confirm at the caller.

    Returns:
        A dict with keys ``"mime_type"`` ("audio/pcm") and ``"data"`` (the
        base64-encoded sample bytes as a str).
    """
    raw_bytes = data.tobytes()
    encoded = base64.b64encode(raw_bytes).decode("UTF-8")
    payload = {"mime_type": "audio/pcm"}
    payload["data"] = encoded
    return payload
def encode_image(data: np.ndarray) -> dict:
    """Encode an image array as a base64 JPEG payload for the live session.

    Args:
        data: Pixel array accepted by ``PIL.Image.fromarray`` (for example an
            (H, W, 3) uint8 RGB frame, or an (H, W, 4) RGBA image).

    Returns:
        A dict with keys ``"mime_type"`` ("image/jpeg") and ``"data"`` (the
        base64-encoded JPEG bytes as a str).
    """
    pil_image = Image.fromarray(data)
    # JPEG supports neither alpha nor palettes; Pillow raises OSError when
    # saving such modes as JPEG. Flatten them to RGB (alpha is discarded).
    # Modes that already worked (RGB, L, ...) are left untouched.
    if pil_image.mode in ("RGBA", "LA", "P", "PA"):
        pil_image = pil_image.convert("RGB")
    with BytesIO() as output_bytes:
        pil_image.save(output_bytes, "JPEG")
        bytes_data = output_bytes.getvalue()
    base64_str = base64.b64encode(bytes_data).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": base64_str}
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Bridge a fastrtc audio/video stream to a Gemini live session.

    - Microphone audio is forwarded to Gemini.
    - Webcam frames, plus the uploaded artwork when present, are sent at most
      about once per second.
    - Audio returned by Gemini is queued for playback.
    - Webcam frames are echoed back as the outgoing video track.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__(
            "mono",
            output_sample_rate=24000,
            input_sample_rate=16000,
        )
        # Audio chunks decoded from Gemini, drained by emit().
        self.audio_queue = asyncio.Queue()
        # Webcam frames from the client, echoed back by video_emit().
        self.video_queue = asyncio.Queue()
        # Active Gemini live session; None before start_up connects and after
        # shutdown().
        self.session = None
        # time.time() of the last image upload; 0 so the first frame is sent.
        self.last_frame_time = 0
        # Set by shutdown() to stop the receive loop in start_up().
        self.quit = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh, unconnected handler.

        fastrtc presumably calls this once per connection.
        """
        return GeminiHandler()

    async def start_up(self):
        """Connect to Gemini and pump its audio replies into ``audio_queue``.

        Reads the Gemini API key and HF token from the additional inputs,
        opens a live session configured for audio responses, and loops until
        ``quit`` is set or the connection is closed normally.

        Raises:
            WebRTCError: if no HF token was provided.
        """
        await self.wait_for_args()
        # Indices follow Stream(additional_inputs=...). Presumably index 0 is
        # the WebRTC component and 1 the Markdown; verify against fastrtc.
        api_key = self.latest_args[3]
        hf_token = self.latest_args[4]
        if hf_token is None or hf_token == "":
            raise WebRTCError("HF Token is required")
        # NOTE(review): os.environ is process-wide, so concurrent sessions
        # overwrite each other's token here.
        os.environ["HF_TOKEN"] = hf_token
        client = genai.Client(
            api_key=api_key, http_options={"api_version": "v1alpha"}
        )
        # NOTE(review): the prompt repeats several sentences. It is kept
        # verbatim to avoid changing model behaviour.
        config = {
            "response_modalities": ["AUDIO"],
            "system_instruction": (
                "You are an art critic that will critique the artwork passed in as an "
                "image to the user. Critique the artwork in a funny and lighthearted "
                "way. Be concise and to the point. Be friendly and engaging. Be helpful "
                "and informative. Be funny and lighthearted. Be concise and to the "
                "point. Be friendly and engaging."
            ),
        }
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp",
            config=config,  # type: ignore
        ) as session:
            self.session = session
            while not self.quit.is_set():
                # Use the local `session`: shutdown() clears self.session.
                turn = session.receive()
                try:
                    async for response in turn:
                        if data := response.data:
                            # Decode as int16 PCM, shaped (1, n_samples).
                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
                            self.audio_queue.put_nowait(audio)
                except websockets.exceptions.ConnectionClosedOK:
                    print("connection closed")
                    break

    async def video_receive(self, frame: np.ndarray):
        """Queue a webcam frame for echo and forward it to Gemini.

        Uploads to Gemini are throttled to roughly one per second.
        """
        self.video_queue.put_nowait(frame)
        # Local reference: shutdown() may clear self.session between awaits.
        session = self.session
        if not session:
            return
        now = time.time()
        # Send an image at most about once per second.
        if now - self.last_frame_time > 1:
            self.last_frame_time = now
            await session.send(input=encode_image(frame))
            # Also send the uploaded artwork, if any.
            artwork = self.latest_args[2]
            if artwork is not None:
                await session.send(input=encode_image(artwork))

    async def video_emit(self):
        """Return the next queued webcam frame.

        Falls back to a black 100x100 RGB placeholder when no frame arrives
        within the ``wait_for_item`` timeout (0.01, presumably in seconds).
        """
        frame = await wait_for_item(self.video_queue, 0.01)
        if frame is not None:
            return frame
        else:
            return np.zeros((100, 100, 3), dtype=np.uint8)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward a microphone frame ``(sample_rate, samples)`` to Gemini.

        Does nothing if no session is open.
        """
        session = self.session
        if not session:
            return
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        await session.send(input=audio_message)

    async def emit(self):
        """Return ``(output_sample_rate, chunk)`` for the next Gemini audio chunk.

        Returns None if nothing arrives within the ``wait_for_item`` timeout.
        """
        array = await wait_for_item(self.audio_queue, 0.01)
        if array is None:
            return None
        return (self.output_sample_rate, array)

    async def shutdown(self) -> None:
        """Stop the receive loop and close the Gemini session, if one is open."""
        session = self.session
        if session:
            self.quit.set()
            # Drop the reference first so receive()/video_receive() stop
            # sending to a session that is being closed.
            self.session = None
            await session.close()
            self.quit.clear()
# Template handler passed to fastrtc (GeminiHandler.copy() creates fresh
# instances; presumably one per connection).
gemini_handler = GeminiHandler()
stream = Stream(
    handler=gemini_handler,
    modality="audio-video",
    mode="send-receive",
    # Server-side TURN credentials requested with ttl=600 * 10000.
    server_rtc_configuration=get_hf_turn_credentials(ttl=600 * 10000),
    rtc_configuration=get_hf_turn_credentials(),
    # Order matters: GeminiHandler reads the artwork, API key and HF token by
    # position from latest_args (presumably indices 2, 3 and 4).
    additional_inputs=[
        gr.Markdown(
            "## 🎨 Art Critic\n\n"
            "Provide an image of your artwork or hold it up to the webcam and "
            "Gemini will critique it for you.\n\n"
            "To get a Gemini API key, please visit the "
            "[Gemini API Key](https://aistudio.google.com/apikey) page.\n\n"
            "To get an HF Token, please visit the "
            "[HF Token](https://huggingface.co/settings/tokens) page. "
            "The token requires Read access."
        ),
        gr.Image(label="Artwork", value="mona_lisa.jpg", type="numpy", sources=["upload", "clipboard"]),
        gr.Textbox(label="Gemini API Key", type="password"),
        gr.Textbox(label="HF Token", type="password"),
    ],
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(255, 255, 255)",
        "icon_button_color": "rgb(255, 255, 255)",
        "title": "Gemini Audio Video Chat",
    },
    time_limit=90,
    concurrency_limit=5,
)

if __name__ == "__main__":
    stream.ui.launch()