{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# app\n",
    "\n",
    "> Gradio app.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp app"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "import copy\n",
    "import os\n",
    "import gradio as gr\n",
    "import constants\n",
    "from lv_recipe_chatbot.vegan_recipe_assistant import (\n",
    "    SYSTEM_PROMPT,\n",
    "    vegan_recipe_edamam_search,\n",
    "    VEGAN_RECIPE_SEARCH_TOOL_SCHEMA,\n",
    ")\n",
    "from openai import OpenAI, AssistantEventHandler\n",
    "from typing_extensions import override\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import time\n",
    "from dotenv import load_dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#: eval: false\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need an event handler to send the streaming output to the Gradio app.  \n",
    "[GPT-4 streaming output example on Hugging Face 🤗](https://huggingface.co/spaces/ysharma/ChatGPT4/blob/main/app.py)  \n",
    "[Gradio Lite lets you embed a Gradio app in the browser with JS](https://www.gradio.app/guides/gradio-lite)  \n",
    "[Streaming outputs](https://www.gradio.app/main/guides/streaming-outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EventHandler(AssistantEventHandler):\n",
    "    \"\"\"Streams assistant text to a callback and answers recipe-search tool calls.\n",
    "\n",
    "    Args:\n",
    "        handle_text_delta: called with each new chunk of assistant text.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, handle_text_delta):\n",
    "        # the base initializer sets up the stream state (e.g. `current_run`)\n",
    "        # that the event hooks below rely on\n",
    "        super().__init__()\n",
    "        self.handle_text_delta = handle_text_delta\n",
    "\n",
    "    @override\n",
    "    def on_text_delta(self, delta, snapshot):\n",
    "        self.handle_text_delta(delta.value)\n",
    "\n",
    "    @override\n",
    "    def on_event(self, event):\n",
    "        # Retrieve events that are denoted with 'requires_action'\n",
    "        # since these will have our tool_calls\n",
    "        if event.event == \"thread.run.requires_action\":\n",
    "            run_id = event.data.id  # Retrieve the run ID from the event data\n",
    "            self.handle_requires_action(event.data, run_id)\n",
    "\n",
    "    def handle_requires_action(self, data, run_id):\n",
    "        \"\"\"Run each requested `vegan_recipe_edamam_search` call and submit the outputs.\"\"\"\n",
    "        tool_outputs = []\n",
    "        for tool_call in data.required_action.submit_tool_outputs.tool_calls:\n",
    "            if tool_call.function.name == \"vegan_recipe_edamam_search\":\n",
    "                fn_args = json.loads(tool_call.function.arguments)\n",
    "                output = vegan_recipe_edamam_search(\n",
    "                    query=fn_args.get(\"query\"),\n",
    "                )\n",
    "                tool_outputs.append({\"tool_call_id\": tool_call.id, \"output\": output})\n",
    "\n",
    "        self.submit_tool_outputs(tool_outputs, run_id)\n",
    "\n",
    "    def submit_tool_outputs(self, tool_outputs, run_id):\n",
    "        \"\"\"Submit `tool_outputs` for `run_id` and stream the resumed run to completion.\"\"\"\n",
    "        with client.beta.threads.runs.submit_tool_outputs_stream(\n",
    "            thread_id=self.current_run.thread_id,\n",
    "            run_id=run_id,\n",
    "            tool_outputs=tool_outputs,\n",
    "            # a fresh handler per stream, forwarding text to the same callback\n",
    "            event_handler=EventHandler(self.handle_text_delta),\n",
    "        ) as stream:\n",
    "            stream.until_done()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = OpenAI()\n",
    "# NOTE: this creates a new assistant on the OpenAI account every time the cell runs\n",
    "assistant = client.beta.assistants.create(\n",
    "    name=\"Vegan Recipe Finder\",\n",
    "    instructions=SYSTEM_PROMPT\n",
    "    + \"\\nChoose the best single matching recipe to the user's query out of the vegan recipe search returned recipes\",\n",
    "    model=\"gpt-4o\",\n",
    "    tools=[VEGAN_RECIPE_SEARCH_TOOL_SCHEMA],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_conversation() -> str:\n",
    "    \"\"\"Run the assistant on the global `thread` and return its latest reply.\n",
    "\n",
    "    `vegan_recipe_edamam_search` tool calls requested by the run are executed\n",
    "    locally and their outputs submitted until the run stops asking for them.\n",
    "    Uses the module-level `client`, `thread` and `assistant`.\n",
    "\n",
    "    Returns:\n",
    "        The text of the newest message on the thread, or an apology string when\n",
    "        tool submission fails or the run ends without completing.\n",
    "    \"\"\"\n",
    "    # create_and_poll blocks until the run completes, fails or needs tool outputs\n",
    "    run = client.beta.threads.runs.create_and_poll(\n",
    "        thread_id=thread.id,\n",
    "        assistant_id=assistant.id,\n",
    "    )\n",
    "    while run.status == \"requires_action\":\n",
    "        tool_outputs = []\n",
    "        for tool_call in run.required_action.submit_tool_outputs.tool_calls:\n",
    "            if tool_call.function.name == \"vegan_recipe_edamam_search\":\n",
    "                fn_args = json.loads(tool_call.function.arguments)\n",
    "                output = vegan_recipe_edamam_search(\n",
    "                    query=fn_args.get(\"query\"),\n",
    "                )\n",
    "                tool_outputs.append({\"tool_call_id\": tool_call.id, \"output\": output})\n",
    "\n",
    "        if not tool_outputs:\n",
    "            # no tool call we know how to answer: waiting would never make progress\n",
    "            return \"Sorry failed to run tools. Try again with a different query.\"\n",
    "\n",
    "        try:\n",
    "            # blocks until the run finishes or asks for more tool outputs\n",
    "            run = client.beta.threads.runs.submit_tool_outputs_and_poll(\n",
    "                thread_id=thread.id,\n",
    "                run_id=run.id,\n",
    "                tool_outputs=tool_outputs,\n",
    "            )\n",
    "            print(\"Tool outputs submitted successfully.\")\n",
    "        except Exception as e:\n",
    "            print(\"Failed to submit tool outputs:\", e)\n",
    "            return \"Sorry failed to run tools. Try again with a different query.\"\n",
    "\n",
    "    if run.status != \"completed\":\n",
    "        # failed / cancelled / expired / incomplete: report instead of spinning forever\n",
    "        print(f\"Run ended with status {run.status!r}:\", run.last_error)\n",
    "        return \"Sorry, something went wrong. Try again with a different query.\"\n",
    "\n",
    "    messages = client.beta.threads.messages.list(thread_id=thread.id)\n",
    "    # messages are listed newest first, so index 0 is the assistant's reply\n",
    "    return messages.data[0].content[0].text.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://www.gradio.app/main/guides/creating-a-chatbot-fast#customizing-your-chatbot\n",
    "\n",
    "\n",
    "# on chatbot start/ first msg after clear\n",
    "thread = client.beta.threads.create()\n",
    "\n",
    "\n",
    "def predict(message, history):\n",
    "    \"\"\"Gradio multimodal `ChatInterface` callback.\n",
    "\n",
    "    Posts the user's text and, if present, the most recently uploaded file (as an\n",
    "    image) to the shared assistant `thread`, then runs the assistant.\n",
    "\n",
    "    Args:\n",
    "        message: dict with a \"text\" entry and a \"files\" list of local file paths.\n",
    "        history: unused; the conversation state lives in the OpenAI thread.\n",
    "\n",
    "    Returns:\n",
    "        The assistant's reply from `run_conversation`.\n",
    "    \"\"\"\n",
    "    txt = message[\"text\"]\n",
    "    if txt:\n",
    "        client.beta.threads.messages.create(\n",
    "            thread_id=thread.id,\n",
    "            role=\"user\",\n",
    "            content=txt,\n",
    "        )\n",
    "    # files is only from the last message rather than all historically submitted files\n",
    "    files = message[\"files\"]\n",
    "    if files:\n",
    "        # upload for vision; `with` closes the local file handle afterwards\n",
    "        with open(files[-1], \"rb\") as image_file:\n",
    "            uploaded = client.files.create(file=image_file, purpose=\"vision\")\n",
    "        client.beta.threads.messages.create(\n",
    "            thread_id=thread.id,\n",
    "            content=[\n",
    "                {\n",
    "                    \"type\": \"text\",\n",
    "                    \"text\": \"What vegan ingredients do you see in this image?\",\n",
    "                },\n",
    "                {\"type\": \"image_file\", \"image_file\": {\"file_id\": uploaded.id}},\n",
    "            ],\n",
    "            role=\"user\",\n",
    "        )\n",
    "    return run_conversation()\n",
    "\n",
    "\n",
    "# print(predict({\"text\": \"yo\", \"files\": []}, []))\n",
    "# print(predict({\"text\": \"suggest a tofu and greens recipe please\", \"files\": []}, []))\n",
    "# print(predict({\"text\": \"burger\", \"files\": []}, []))\n",
    "# close a server left over from a previous run of this cell before relaunching\n",
    "if \"demo\" in globals():\n",
    "    demo.close()\n",
    "\n",
    "demo = gr.ChatInterface(fn=predict, multimodal=True)\n",
    "demo.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_demo():\n",
    "    \"\"\"Build the full recipe-chatbot UI without launching it.\n",
    "\n",
    "    Renders the multimodal `predict` chat interface with the Edamam attribution\n",
    "    below it.\n",
    "\n",
    "    Returns:\n",
    "        The assembled `gr.Blocks` app.\n",
    "    \"\"\"\n",
    "    chat = gr.ChatInterface(fn=predict, multimodal=True)\n",
    "    with gr.Blocks() as demo:\n",
    "        chat.render()\n",
    "        gr.Markdown(\n",
    "            \"\"\"**🔃Refresh the page to start from scratch🔃**  \n",
    "\n",
    "            Recipe search tool powered by the [Edamam API](https://www.edamam.com/)  \n",
    "\n",
    "            ![Edamam Logo](https://www.edamam.com/assets/img/small-logo.png)\"\"\"\n",
    "        )\n",
    "    return demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev\n",
    "\n",
    "nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "local-lv-chatbot",
   "language": "python",
   "name": "local-lv-chatbot"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}