#| export
import copy
import os

import gradio as gr
from fastcore.utils import in_jupyter
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from PIL import Image

import constants
from lv_recipe_chatbot.engineer_prompt import INIT_PROMPT
from lv_recipe_chatbot.ingredient_vision import (
    SAMPLE_IMG_DIR,
    BlipImageCaptioning,
    VeganIngredientFinder,
    format_image,
)

# | export


class ConversationBot:
    """Recipe chatbot: a short scripted intro followed by an open-ended chat.

    The first bot replies are canned questions taken from ``INIT_PROMPT``.
    Once the user has answered them, the answers are formatted into the
    prompt, sent to the chat model once, and that exchange seeds a
    ``ConversationChain`` which handles every later turn.
    """

    def __init__(self, verbose=True):
        # Pass ``verbose`` through to the model as well; it was previously
        # hard-coded to True, so ``verbose=False`` had no effect here.
        self.chat = ChatOpenAI(temperature=1, verbose=verbose)
        self.memory = ConversationBufferMemory(return_messages=True)
        self._load_init_prompt()
        self.img_cap = BlipImageCaptioning("cpu")
        self.vegan_ingred_finder = VeganIngredientFinder()
        self.verbose = verbose

    def _load_init_prompt(self):
        """(Re)create this bot's prompt copy and the canned questions that index into it."""
        # Deep copy so ``run_img`` can edit this instance's prompt without
        # mutating the shared module-level INIT_PROMPT.
        self.init_prompt = copy.deepcopy(INIT_PROMPT)
        init_prompt_msgs = self.init_prompt.messages
        # NOTE(review): indices 1/3/5 are assumed to hold the scripted bot
        # questions named by these keys — confirm against engineer_prompt.
        self.ai_prompt_questions = {
            "ingredients": init_prompt_msgs[1],
            "allergies": init_prompt_msgs[3],
            "recipe_open_params": init_prompt_msgs[5],
        }

    def respond(self, user_msg, chat_history):
        """Handle one user message.

        Appends ``(user_msg, reply)`` to ``chat_history`` in place and returns
        ``("", chat_history)``. The empty string lets a UI caller clear its
        input box.
        """
        response = self._get_bot_response(user_msg, chat_history)
        chat_history.append((user_msg, response))
        return "", chat_history

    def init_conversation(self, formatted_chat_prompt):
        """Build the open-ended ConversationChain sharing this bot's model and memory."""
        self.conversation = ConversationChain(
            llm=self.chat,
            memory=self.memory,
            prompt=formatted_chat_prompt,
            verbose=self.verbose,
        )

    def reset(self):
        """Clear conversation memory and discard any image context added by ``run_img``."""
        self.memory.clear()
        # Rebuild both the prompt copy and ``ai_prompt_questions`` so the
        # latter does not keep pointing into the discarded copy.
        self._load_init_prompt()

    def run_img(self, image: str):
        """Describe an image and prepend its ingredients to the ingredients prompt.

        ``image`` is passed to both the captioner (via ``format_image``) and the
        ingredient finder. The module script also passes a ``pathlib`` path here.
        Returns the human-readable summary message.
        """
        desc = self.img_cap.inference(format_image(image))
        answer = self.vegan_ingred_finder.list_ingredients(image)
        msg = f"""I uploaded an image that may contain vegan ingredients.
The description of the image is: `{desc}`.
The extracted ingredients are:
```
{answer}
```
"""
        # The summary is spliced into a prompt *template*, so literal braces in
        # the model-generated caption or ingredient list would otherwise be
        # parsed as template variables. Double them; this assumes INIT_PROMPT
        # uses LangChain's default f-string template format — TODO confirm.
        template_safe_msg = msg.replace("{", "{{").replace("}", "}}")
        # Build on the pristine INIT_PROMPT template so that repeated uploads
        # replace, rather than accumulate, the image context.
        base_prompt = INIT_PROMPT.messages[2].prompt.template
        self.init_prompt.messages[2].prompt.template = (
            f"{template_safe_msg}I may type some more ingredients below.\n{base_prompt}"
        )
        return msg

    def _get_bot_response(self, user_msg: str, chat_history) -> str:
        """Return the bot's reply, depending on how far the conversation has got."""
        n_pairs = len(chat_history)

        if n_pairs < 2:
            return self.ai_prompt_questions["allergies"].prompt.template

        if n_pairs < 3:
            return self.ai_prompt_questions["recipe_open_params"].prompt.template

        if n_pairs < 4:
            # Here n_pairs == 3. The user's answers to the ingredients and
            # allergies questions are the user side of pairs 1 and 2 (pair 0
            # is presumably the seeded opening question).
            ingredients = chat_history[1][0]
            allergies = chat_history[2][0]
            f_init_prompt = self.init_prompt.format_prompt(
                ingredients=ingredients,
                allergies=allergies,
                recipe_freeform_input=user_msg,
            )
            chat_msgs = f_init_prompt.to_messages()
            results = self.chat.generate([chat_msgs])
            first_reply = results.generations[0][0].message
            # Freeze the scripted exchange plus the model's first reply, then
            # leave slots for memory history and the next user input.
            chat_msgs.extend(
                [
                    first_reply,
                    MessagesPlaceholder(variable_name="history"),
                    HumanMessagePromptTemplate.from_template("{input}"),
                ]
            )
            self.init_conversation(ChatPromptTemplate.from_messages(chat_msgs))
            return first_reply.content

        return self.conversation.predict(input=user_msg)

    def __del__(self):
        # __del__ also runs when __init__ raised before this attribute was
        # assigned; skip the release instead of raising AttributeError.
        try:
            del self.vegan_ingred_finder
        except AttributeError:
            pass

from dotenv import load_dotenv

# Populate environment variables from a local .env file before the bot is
# built (ConversationBot constructs its ChatOpenAI client in __init__);
# presumably this supplies the OpenAI credentials.
load_dotenv()

bot = ConversationBot()

# Smoke-test the image pipeline on a sample image.
print(bot.run_img(SAMPLE_IMG_DIR / "veggie-fridge.jpeg"))

# NOTE(review): this file previously had notebook-JSON residue fused onto the
# `print(...)` line above, which made the module a SyntaxError. It contained
# the tail of a Gradio `create_demo(bot)` builder (its head was missing, so it
# could not be reconstructed here) plus notebook-only launch/export cells.
# Restore `create_demo` from the source notebook. Surviving builder fragment:
#     clear = gr.Button("Clear")
#     gr.Markdown("**🔃Refresh the page to start from scratch🔃**")
#
#     msg.submit(
#         fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
#     )
#     # clear.click(lambda: None, None, chatbot, queue=False).then(bot.reset)
#     return demo
# Launch cell (`#| eval: false`):
#     if "demo" in globals():
#         demo.close()
#     demo = create_demo(bot)
#     demo.launch()
# Export cell (`#| hide`):
#     import nbdev
#     nbdev.nbdev_export()