{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# app\n", "\n", "> Gradio app.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| default_exp app" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# | export\n", "import copy\n", "import os\n", "import gradio as gr\n", "import constants\n", "from lv_recipe_chatbot.vegan_recipe_assistant import (\n", " SYSTEM_PROMPT,\n", " vegan_recipe_edamam_search,\n", " VEGAN_RECIPE_SEARCH_TOOL_SCHEMA,\n", ")\n", "from openai import OpenAI, AssistantEventHandler\n", "from typing_extensions import override\n", "import json" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "import time\n", "from dotenv import load_dotenv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#: eval: false\n", "load_dotenv()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Need an even handler to send the streaming output to the Gradio app \n", "[GPT4 streaming output example on hugging face 🤗](https://huggingface.co/spaces/ysharma/ChatGPT4/blob/main/app.pyhttps://huggingface.co/spaces/ysharma/ChatGPT4/blob/main/app.py) \n", "[Gradio lite let's you insert Gradio app in browser JS](https://www.gradio.app/guides/gradio-litehttps://www.gradio.app/guides/gradio-lite) \n", "[Streaming output](https://www.gradio.app/main/guides/streaming-outputshttps://www.gradio.app/main/guides/streaming-outputs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class EventHandler(AssistantEventHandler):\n", " def __init__(self, handle_text_delta):\n", " self.handle_text_delta = handle_text_delta\n", "\n", " @override\n", " def on_text_delta(self, delta, snapshot):\n", " self.handle_text_delta(delta.value)\n", "\n", " @override\n", " def on_event(self, event):\n", " # Retrieve events that are denoted with 'requires_action'\n", " # since these will have our tool_calls\n", " if event.event == \"thread.run.requires_action\":\n", " run_id = event.data.id # Retrieve the run ID from the event data\n", " self.handle_requires_action(event.data, run_id)\n", "\n", " def handle_requires_action(self, data, run_id):\n", " tool_outputs = []\n", " for tool_call in data.required_action.submit_tool_outputs.tool_calls:\n", " if tool_call.function.name == \"vegan_recipe_edamam_search\":\n", " fn_args = json.loads(tool_call.function.arguments)\n", " data = vegan_recipe_edamam_search(\n", " query=fn_args.get(\"query\"),\n", " )\n", " tool_outputs.append({\"tool_call_id\": tool_call.id, \"output\": data})\n", "\n", " self.submit_tool_outputs(tool_outputs, run_id)\n", "\n", " def submit_tool_outputs(self, tool_outputs, run_id):\n", " with client.beta.threads.runs.submit_tool_outputs_stream(\n", " thread_id=self.current_run.thread_id,\n", " run_id=self.current_run.id,\n", " tool_outputs=tool_outputs,\n", " event_handler=EventHandler(),\n", " ) as stream:\n", " for text in stream.until_:\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "client = OpenAI()\n", "assistant = client.beta.assistants.create(\n", " name=\"Vegan Recipe 
Finder\",\n", " instructions=SYSTEM_PROMPT\n", " + \"\\nChoose the best single matching recipe to the user's query out of the vegan recipe search returned recipes\",\n", " model=\"gpt-4o\",\n", " tools=[VEGAN_RECIPE_SEARCH_TOOL_SCHEMA],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def run_conversation() -> str:\n", " run = client.beta.threads.runs.create_and_poll(\n", " thread_id=thread.id,\n", " assistant_id=assistant.id,\n", " )\n", " while True:\n", " tool_outputs = []\n", " tool_calls = (\n", " []\n", " if not run.required_action\n", " else run.required_action.submit_tool_outputs.tool_calls\n", " )\n", "\n", " for tool_call in tool_calls:\n", " if tool_call.function.name == \"vegan_recipe_edamam_search\":\n", " fn_args = json.loads(tool_call.function.arguments)\n", " data = vegan_recipe_edamam_search(\n", " query=fn_args.get(\"query\"),\n", " )\n", " tool_outputs.append({\"tool_call_id\": tool_call.id, \"output\": data})\n", "\n", " if tool_outputs:\n", " try:\n", " run = client.beta.threads.runs.submit_tool_outputs_and_poll(\n", " thread_id=thread.id,\n", " run_id=run.id,\n", " tool_outputs=tool_outputs,\n", " )\n", " print(\"Tool outputs submitted successfully.\")\n", "\n", " except Exception as e:\n", " print(\"Failed to submit tool outputs:\", e)\n", " return \"Sorry failed to run tools. Try again with a different query.\"\n", "\n", " if run.status == \"completed\":\n", " messages = client.beta.threads.messages.list(thread_id=thread.id)\n", " data = messages.data\n", " content = data[0].content\n", " return content[0].text.value\n", " time.sleep(0.05)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closing server running on port: 7860\n", "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "