from __future__ import annotations as _annotations import json import os from contextlib import asynccontextmanager from dataclasses import dataclass from typing import AsyncGenerator import asyncpg import gradio as gr import numpy as np import pydantic_core from gradio_webrtc import ( AdditionalOutputs, ReplyOnPause, WebRTC, audio_to_bytes, get_twilio_turn_credentials, ) from groq import Groq from openai import AsyncOpenAI from pydantic import BaseModel from pydantic_ai import RunContext from pydantic_ai.agent import Agent from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn DOCS = json.load(open("gradio_docs.json")) groq_client = Groq() openai = AsyncOpenAI() @dataclass class Deps: openai: AsyncOpenAI pool: asyncpg.Pool SYSTEM_PROMPT = ( "You are an assistant designed to help users answer questions about Gradio. " "You have a retrival tool that can provide relevant documentation sections based on the user query. " "Be curteous and helpful to the user but feel free to refuse answering questions that are not about Gradio. " ) agent = Agent( "openai:gpt-4o", deps_type=Deps, system_prompt=SYSTEM_PROMPT, ) class RetrievalResult(BaseModel): content: str ids: list[int] @asynccontextmanager async def database_connect( create_db: bool = False, ) -> AsyncGenerator[asyncpg.Pool, None]: server_dsn, database = ( os.getenv("DATABASE_URL"), "gradio_ai_rag", ) if create_db: conn = await asyncpg.connect(server_dsn) try: db_exists = await conn.fetchval( "SELECT 1 FROM pg_database WHERE datname = $1", database ) if not db_exists: await conn.execute(f"CREATE DATABASE {database}") finally: await conn.close() pool = await asyncpg.create_pool(f"{server_dsn}/{database}") try: yield pool finally: await pool.close() @agent.tool async def retrieve(context: RunContext[Deps], search_query: str) -> str: """Retrieve documentation sections based on a search query. Args: context: The call context. search_query: The search query. """ print(f"create embedding for {search_query}") embedding = await context.deps.openai.embeddings.create( input=search_query, model="text-embedding-3-small", ) assert ( len(embedding.data) == 1 ), f"Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}" embedding = embedding.data[0].embedding embedding_json = pydantic_core.to_json(embedding).decode() rows = await context.deps.pool.fetch( "SELECT id, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8", embedding_json, ) content = "\n\n".join(f'# {row["title"]}\n{row["content"]}\n' for row in rows) ids = [row["id"] for row in rows] return RetrievalResult(content=content, ids=ids).model_dump_json() async def stream_from_agent( audio: tuple[int, np.ndarray], chatbot: list[dict], past_messages: list ): question = groq_client.audio.transcriptions.create( file=("audio-file.mp3", audio_to_bytes(audio)), model="whisper-large-v3-turbo", response_format="verbose_json", ).text print("text", question) chatbot.append({"role": "user", "content": question}) yield AdditionalOutputs(chatbot, gr.skip()) async with database_connect(False) as pool: deps = Deps(openai=openai, pool=pool) async with agent.run_stream( question, deps=deps, message_history=past_messages ) as result: for message in result.new_messages(): past_messages.append(message) if isinstance(message, ModelStructuredResponse): for call in message.calls: gr_message = { "role": "assistant", "content": "", "metadata": { "title": "🔍 Retrieving relevant docs", "id": call.tool_id, }, } chatbot.append(gr_message) if isinstance(message, ToolReturn): for gr_message in chatbot: if ( gr_message.get("metadata", {}).get("id", "") == message.tool_id ): paths = [] for d in DOCS: tool_result = RetrievalResult.model_validate_json( message.content ) if d["id"] in tool_result.ids: paths.append(d["path"]) paths = '\n'.join(list(set(paths))) gr_message["content"] = ( f"Relevant Context:\n {paths}" ) yield AdditionalOutputs(chatbot, gr.skip()) chatbot.append({"role": "assistant", "content": ""}) async for message in result.stream_text(): chatbot[-1]["content"] = message yield AdditionalOutputs(chatbot, gr.skip()) data = await result.get_data() past_messages.append(ModelTextResponse(content=data)) yield AdditionalOutputs(gr.skip(), past_messages) with gr.Blocks() as demo: placeholder = """

Chat with Gradio Docs 🗣️

Simple RAG agent over Gradio docs built with Pydantic AI.

Ask any question about Gradio with your natural voice and get an answer!

""" past_messages = gr.State([]) chatbot = gr.Chatbot( label="Gradio Docs Bot", type="messages", placeholder=placeholder, avatar_images=(None, "logo.svg"), ) audio = WebRTC( label="Talk with the Agent", modality="audio", rtc_configuration=get_twilio_turn_credentials(), mode="send", ) audio.stream( ReplyOnPause(stream_from_agent), inputs=[audio, chatbot, past_messages], outputs=[audio], ) audio.on_additional_outputs( lambda c, s: (c, s), outputs=[chatbot, past_messages], queue=False, show_progress="hidden", ) if __name__ == "__main__": demo.launch(allowed_paths=["logo.svg"])