{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# app\n", "\n", "> Gradio app.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| default_exp app" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "import os\n", "\n", "import gradio as gr\n", "from fastcore.utils import in_jupyter\n", "from langchain.chains import ConversationChain\n", "from langchain.chat_models import ChatOpenAI\n", "from langchain.memory import ConversationBufferMemory\n", "from langchain.prompts.chat import (\n", " ChatPromptTemplate,\n", " HumanMessagePromptTemplate,\n", " MessagesPlaceholder,\n", ")\n", "\n", "from lv_recipe_chatbot.engineer_prompt import init_prompt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "class ConversationBot:\n", " def __init__(\n", " self,\n", " ):\n", " self.chat = ChatOpenAI(temperature=1, verbose=True)\n", " self.memory = ConversationBufferMemory(return_messages=True)\n", " self.init_prompt_msgs = init_prompt.messages\n", " self.ai_prompt_questions = {\n", " \"ingredients\": self.init_prompt_msgs[1],\n", " \"allergies\": self.init_prompt_msgs[3],\n", " \"recipe_open_params\": self.init_prompt_msgs[5],\n", " }\n", "\n", " def respond(self, user_msg, chat_history):\n", " response = self._get_bot_response(user_msg, chat_history)\n", " chat_history.append((user_msg, response))\n", " return \"\", chat_history\n", "\n", " def init_conversation(self, formatted_chat_prompt):\n", " self.conversation = ConversationChain(\n", " llm=self.chat,\n", " memory=self.memory,\n", " prompt=formatted_chat_prompt,\n", " verbose=True,\n", " )\n", "\n", " def reset(self):\n", " self.memory.clear()\n", "\n", " def _get_bot_response(self, user_msg: str, chat_history) -> str:\n", " if len(chat_history) < 2:\n", " return self.ai_prompt_questions[\"allergies\"].prompt.template\n", "\n", " if len(chat_history) < 3:\n", " return self.ai_prompt_questions[\"recipe_open_params\"].prompt.template\n", "\n", " if len(chat_history) < 4:\n", " user = 0\n", " ai = 1\n", " user_msgs = [msg_pair[user] for msg_pair in chat_history[1:]]\n", " f_init_prompt = init_prompt.format_prompt(\n", " ingredients=user_msgs[0],\n", " allergies=user_msgs[1],\n", " recipe_freeform_input=user_msg,\n", " )\n", " chat_msgs = f_init_prompt.to_messages()\n", " results = self.chat.generate([chat_msgs])\n", " chat_msgs.extend(\n", " [\n", " results.generations[0][0].message,\n", " MessagesPlaceholder(variable_name=\"history\"),\n", " HumanMessagePromptTemplate.from_template(\"{input}\"),\n", " ]\n", " )\n", " open_prompt = ChatPromptTemplate.from_messages(chat_msgs)\n", " # prepare the open conversation chain from this point\n", " self.init_conversation(open_prompt)\n", " return results.generations[0][0].message.content\n", "\n", " response = self.conversation.predict(input=user_msg)\n", " return response" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "def launch_demo():\n", " with gr.Blocks() as demo:\n", " bot = ConversationBot()\n", " chatbot = gr.Chatbot(\n", " value=[(None, bot.ai_prompt_questions[\"ingredients\"].prompt.template)]\n", " )\n", "\n", " msg = gr.Textbox()\n", " clear = gr.Button(\"Clear\")\n", "\n", " msg.submit(\n", " fn=bot.respond, 
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "def launch_demo():\n", "    \"\"\"Build the Gradio Blocks chat UI and launch it behind basic auth.\"\"\"\n", "    with gr.Blocks() as demo:\n", "        bot = ConversationBot()\n", "        # Seed the chat window with the first scripted question (ingredients)\n", "        chatbot = gr.Chatbot(\n", "            value=[(None, bot.ai_prompt_questions[\"ingredients\"].prompt.template)]\n", "        )\n", "\n", "        msg = gr.Textbox()\n", "        clear = gr.Button(\"Clear\")\n", "\n", "        msg.submit(\n", "            fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False\n", "        )\n", "        clear.click(lambda: None, None, chatbot, queue=False).then(bot.reset)\n", "\n", "    # Credentials for the hosted demo come from the environment\n", "    demo.launch(\n", "        auth=(\n", "            os.environ[\"GRADIO_DEMO_USERNAME\"],\n", "            os.environ[\"GRADIO_DEMO_PASSWORD\"],\n", "        )\n", "    )" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Load the environment variables the demo needs (the OpenAI API key and the Gradio auth credentials)." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from dotenv import load_dotenv" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7862\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] } ], "source": [ "load_dotenv()\n", "launch_demo()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "import nbdev\n", "\n", "nbdev.nbdev_export()" ] }
], "metadata": { "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }