import asyncio
import base64
import time
from io import BytesIO

import gradio as gr
import numpy as np
import google.generativeai as genai
from gradio_webrtc import (
    AsyncAudioVideoStreamHandler,
    WebRTC,
    async_aggregate_bytes_to_16bit,
    VideoEmitType,
    AudioEmitType,
)
from PIL import Image

# WARNING: Embedding API keys directly is not secure.
# Use environment variables or a secrets manager in production.
GOOGLE_API_KEY = "AIzaSyC7dAwSyLKaVO2E-PA6UaacLZ4aLGtrXbY"

# Prompt base
PROMPT_BASE = """
Você é um assistente de voz e vídeo chamado "Aura" e está em um programa de chat. Sua principal tarefa é auxiliar os usuários com uma conversação em tempo real.
Você tem as seguintes características e instruções:

### Persona:
*   Nome: Aura
*   Função: Assistente de voz e vídeo interativo
*   Personalidade: Educada, amigável, prestativa e entusiasmada.
*   Tom de voz: Calmo e reconfortante.
*   Conhecimentos: Você tem acesso a um amplo conhecimento geral.
*   Objetivo: Fornecer respostas concisas e precisas, enquanto mantém a conversa fluida e natural.

### Instruções:
*   Formato de Resposta: Responda em frases curtas e diretas com áudio. Evite respostas longas.
*   Tratamento ao Usuário: Use a segunda pessoa (você) e cumprimente sempre o usuário ao começar e encerrar uma sessão.
*   Reconhecimento de Áudio e Vídeo: Você receberá áudio e video e deverá responder com audio.
*   Gerações de respostas: Não gere respostas que sejam ofensivas, discriminatórias ou que possam causar qualquer dano ou desconforto.
*   Limitações: Se você não souber a resposta, diga "Desculpe, não sei a resposta para isso".
*   Início da conversa: Inicie a conversa sempre com "Olá, em que posso ajudar?"
*  Finalização da conversa: Finalize a conversa sempre com "Foi um prazer conversar com você!"
*  Evite: Não discuta assuntos políticos ou religiosos. Não responda perguntas sobre assuntos pessoais.
*  Contexto: Lembre-se de que você está em um chat de áudio e vídeo em tempo real.
*  Adaptabilidade: Lembre-se de que você pode receber uma imagem e um áudio.

### Início da Conversa:
Olá, em que posso ajudar?
"""


def encode_audio(data: np.ndarray) -> dict:
    """Encode Audio data to send to the server"""
    return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")}


def encode_image(data: np.ndarray) -> dict:
    with BytesIO() as output_bytes:
        pil_image = Image.fromarray(data)
        pil_image.save(output_bytes, "JPEG")
        bytes_data = output_bytes.getvalue()
    base64_str = str(base64.b64encode(bytes_data), "utf-8")
    return {"mime_type": "image/jpeg", "data": base64_str}


class GeminiHandler(AsyncAudioVideoStreamHandler):
    def __init__(
        self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.audio_queue = asyncio.Queue()
        self.video_queue = asyncio.Queue()
        self.quit = asyncio.Event()
        self.session = None
        self.last_frame_time = 0
        self.conversation_history = []  # Added conversation history
        self.latest_text = ""


    def copy(self) -> "GeminiHandler":
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )
    
    async def video_receive(self, frame: np.ndarray):
        if self.session:
            # send image every 1 second
            if time.time() - self.last_frame_time > 1:
                self.last_frame_time = time.time()
                await self.session.send(encode_image(frame))
                if self.latest_args[2] is not None:
                    await self.session.send(encode_image(self.latest_args[2]))
        self.video_queue.put_nowait(frame)
    
    async def video_emit(self) -> VideoEmitType:
        return await self.video_queue.get()

    async def connect(self, api_key: str):
        if self.session is None:
            client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
            config = {"response_modalities": ["AUDIO"]}
            async with client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session:
                self.session = session
                asyncio.create_task(self.receive_audio())
                await self.quit.wait()

    async def generator(self):
        while not self.quit.is_set():
            turn = self.session.receive()
            async for response in turn:
                if data := response.data:
                    yield data
    
    async def receive_audio(self):
        async for audio_response in async_aggregate_bytes_to_16bit(
            self.generator()
        ):
            self.audio_queue.put_nowait(audio_response)

    async def receive(self, frame: tuple[int, np.ndarray], text_input: str) -> None:  # Added text_input here
        _, array = frame
        array = array.squeeze()
        if self.session:
            if text_input: # Checks if text was inputted
                full_prompt = PROMPT_BASE + "\n\n" + "User: " + text_input
                await self.session.send({"mime_type": "text", "data": full_prompt})
                self.conversation_history.append({"role": "user", "content": text_input})  # Add text conversation
            elif array.size: # Checks if audio was received
                full_prompt = PROMPT_BASE + "\n\n" + "User: " + str(base64.b64encode(array.tobytes()).decode("UTF-8"))
                await self.session.send({"mime_type": "text", "data": full_prompt})
                self.conversation_history.append({"role": "user", "content": str(base64.b64encode(array.tobytes()).decode("UTF-8"))})

    async def emit(self) -> AudioEmitType:
        if not self.args_set.is_set():
            await self.wait_for_args()
        if self.session is None:
            asyncio.create_task(self.connect(self.latest_args[1]))
        array = await self.audio_queue.get()
        return (self.output_sample_rate, array)
    
    def set_text(self, text):
       self.latest_text = text
    
    def clear_text(self):
        self.latest_text = ""
        return ""
        

    def shutdown(self) -> None:
        self.quit.set()
        self.connection = None
        self.args_set.clear()
        self.quit.clear()



css = """
#video-source {max-width: 600px !important; max-height: 600 !important;}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
    <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
        <div style="background-color: var(--block-background-fill); border-radius: 8px">
            <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
        </div>
        <div>
            <h1>Gen AI SDK Voice Chat</h1>
            <p>Speak with Gemini using real-time audio + video streaming</p>
            <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href=https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
            <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
        </div>
    </div>
    """
    )
    with gr.Row() as api_key_row:
        api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API Key", value=GOOGLE_API_KEY, visible=False)
    with gr.Row() as row:
        with gr.Column():
            webrtc = WebRTC(
                label="Video Chat",
                modality="audio-video",
                mode="send-receive",
                elem_id="video-source",
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(35, 157, 225)",
                icon_button_color="rgb(35, 157, 225)",
            )
        with gr.Column():
            image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
            text_input = gr.Textbox(label="Text Message", placeholder="Type your message here")
            send_button = gr.Button("Send")

        handler = GeminiHandler()
        send_button.click(handler.set_text, inputs=[text_input], outputs=[])
        send_button.click(handler.clear_text, inputs=[], outputs=[text_input])
        webrtc.stream(
            handler,
            inputs=[webrtc, api_key, image_input, text_input],
            outputs=[webrtc],
            time_limit=90,
            concurrency_limit=2,
        )


if __name__ == "__main__":
    demo.launch()