File size: 5,068 Bytes
adaea7c
 
 
986fa13
adaea7c
 
986fa13
adaea7c
 
ae2bbf3
adaea7c
ae2bbf3
adaea7c
ae2bbf3
 
adaea7c
ae2bbf3
 
 
986fa13
ae2bbf3
986fa13
 
 
 
 
 
 
 
adaea7c
 
ae2bbf3
986fa13
ae2bbf3
 
986fa13
 
ae2bbf3
986fa13
 
 
ae2bbf3
986fa13
 
 
ae2bbf3
 
 
 
 
 
 
 
 
 
 
986fa13
ae2bbf3
 
 
 
986fa13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae2bbf3
 
 
 
 
 
 
 
 
 
 
 
986fa13
ae2bbf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
986fa13
 
 
 
 
 
 
 
 
 
 
 
 
 
adaea7c
986fa13
 
 
 
 
 
 
 
 
adaea7c
 
 
ae2bbf3
adaea7c
986fa13
 
ae2bbf3
adaea7c
 
 
986fa13
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_app.ipynb.

# %% auto 0
__all__ = ['ConversationBot', 'create_demo']

# %% ../nbs/01_app.ipynb 3
import copy
import os

import gradio as gr
from fastcore.utils import in_jupyter
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from PIL import Image

import constants
from .engineer_prompt import INIT_PROMPT
from lv_recipe_chatbot.ingredient_vision import (
    SAMPLE_IMG_DIR,
    BlipImageCaptioning,
    VeganIngredientFinder,
    format_image,
)

# %% ../nbs/01_app.ipynb 4
class ConversationBot:
    """Guided vegan-recipe chatbot backed by an OpenAI chat model.

    The first user turns are scripted: the bot asks for ingredients, then
    allergies, then free-form recipe preferences, using the AI messages of
    ``INIT_PROMPT``. After the third answer the filled-in prompt is sent to the
    model, and an open-ended ``ConversationChain`` handles all later turns.
    """

    def __init__(self, verbose=True):
        # Honour `verbose` for the chat model too (it was hard-coded to True).
        self.chat = ChatOpenAI(temperature=1, verbose=verbose)
        self.memory = ConversationBufferMemory(return_messages=True)
        self._load_init_prompt()
        self.img_cap = BlipImageCaptioning("cpu")
        self.vegan_ingred_finder = VeganIngredientFinder()
        self.verbose = verbose

    def _load_init_prompt(self):
        """Install a fresh deep copy of INIT_PROMPT and point the scripted questions at it."""
        self.init_prompt = copy.deepcopy(INIT_PROMPT)
        init_prompt_msgs = self.init_prompt.messages
        # The scripted AI questions sit at indices 1, 3 and 5 of INIT_PROMPT.
        self.ai_prompt_questions = {
            "ingredients": init_prompt_msgs[1],
            "allergies": init_prompt_msgs[3],
            "recipe_open_params": init_prompt_msgs[5],
        }

    def respond(self, user_msg, chat_history):
        """Append the bot's reply to ``chat_history``.

        Returns:
            ``("", chat_history)``. The empty string clears the input textbox,
            and the history is the list extended with ``(user_msg, reply)``.
        """
        response = self._get_bot_response(user_msg, chat_history)
        chat_history.append((user_msg, response))
        return "", chat_history

    def init_conversation(self, formatted_chat_prompt):
        """Create the open-ended conversation chain using this bot's model and memory."""
        self.conversation = ConversationChain(
            llm=self.chat,
            memory=self.memory,
            prompt=formatted_chat_prompt,
            verbose=self.verbose,
        )

    def reset(self):
        """Clear conversation memory and restore a fresh copy of INIT_PROMPT.

        This also discards any image message injected by ``run_img``.
        """
        self.memory.clear()
        self._load_init_prompt()

    def run_img(self, image: str):
        """Describe an uploaded image and seed the ingredients prompt with it.

        Args:
            image: image path as passed by the UI; it is given to both the
                captioner (via ``format_image``) and the ingredient finder.

        Returns:
            A message containing the image caption and extracted ingredients.
            The same text is also prepended to the human template at index 2
            of this bot's ``init_prompt``.
        """
        desc = self.img_cap.inference(format_image(image))
        answer = self.vegan_ingred_finder.list_ingredients(image)
        msg = f"""I uploaded an image that may contain vegan ingredients.
The description of the image is: `{desc}`.
The extracted ingredients are:
```
{answer}
```
"""
        # Start from the pristine template so repeated uploads replace, rather
        # than accumulate, the image message.
        base_prompt = INIT_PROMPT.messages[2].prompt.template
        # The template is later filled by `format_prompt`. Braces in the model
        # output would be read as placeholders, so escape them. This assumes
        # LangChain's default f-string template format — TODO confirm for INIT_PROMPT.
        escaped_msg = msg.replace("{", "{{").replace("}", "}}")
        new_prompt = (
            f"{escaped_msg}I may type some more ingredients below.\n{base_prompt}"
        )
        self.init_prompt.messages[2].prompt.template = new_prompt
        return msg

    def _get_bot_response(self, user_msg: str, chat_history) -> str:
        """Return the bot's reply to ``user_msg`` given the UI chat history.

        ``chat_history`` holds ``(user, bot)`` pairs. Its first entry is the
        opening ingredients question, whose user side is ``None``. Behaviour
        depends on its length:

        * fewer than 2: the user just gave ingredients, so ask about allergies.
        * 2: the user just gave allergies, so ask for open recipe preferences.
        * 3: fill ``init_prompt`` with the three answers, query the model, and
          set up the conversation chain for later turns.
        * 4 or more: delegate to the conversation chain.
        """
        n_pairs = len(chat_history)
        if n_pairs < 2:
            return self.ai_prompt_questions["allergies"].prompt.template

        if n_pairs < 3:
            return self.ai_prompt_questions["recipe_open_params"].prompt.template

        if n_pairs < 4:
            # Entries 1 and 2 carry the user's ingredients and allergies answers.
            ingredients = chat_history[1][0]
            allergies = chat_history[2][0]
            f_init_prompt = self.init_prompt.format_prompt(
                ingredients=ingredients,
                allergies=allergies,
                recipe_freeform_input=user_msg,
            )
            chat_msgs = f_init_prompt.to_messages()
            results = self.chat.generate([chat_msgs])
            first_reply = results.generations[0][0].message
            # Freeze the scripted exchange and the model's first reply into the
            # prompt. Later turns come in through memory ("history") and "{input}".
            chat_msgs.extend(
                [
                    first_reply,
                    MessagesPlaceholder(variable_name="history"),
                    HumanMessagePromptTemplate.from_template("{input}"),
                ]
            )
            self.init_conversation(ChatPromptTemplate.from_messages(chat_msgs))
            return first_reply.content

        return self.conversation.predict(input=user_msg)

    def __del__(self):
        # __init__ may have raised before this attribute was set; avoid an
        # AttributeError being reported during garbage collection.
        if hasattr(self, "vegan_ingred_finder"):
            del self.vegan_ingred_finder

# %% ../nbs/01_app.ipynb 13
def create_demo(bot=ConversationBot):
    """Build the Gradio UI for the recipe chatbot.

    Args:
        bot: a ``ConversationBot`` instance, or a class to instantiate with no
            arguments. The default is the ``ConversationBot`` class itself, so
            ``create_demo()`` constructs a new bot.

    Returns:
        The ``gr.Blocks`` demo. It is not launched.
    """
    # A class cannot serve `bot.ai_prompt_questions` or a bound `bot.run_img`,
    # so instantiate it first.
    if isinstance(bot, type):
        bot = bot()

    # os.listdir order is arbitrary; sort so the same examples appear every run.
    all_imgs = sorted(f"{SAMPLE_IMG_DIR}/{img}" for img in os.listdir(SAMPLE_IMG_DIR))
    sample_images = all_imgs[1:4]

    with gr.Blocks() as demo:
        gr_img = gr.Image(type="filepath")
        btn = gr.Button(value="Submit image")
        ingredients_msg = gr.Text(label="Ingredients from image")
        btn.click(bot.run_img, inputs=[gr_img], outputs=[ingredients_msg])
        gr.Examples(
            examples=sample_images,
            inputs=gr_img,
        )

        # Open the chat with the scripted ingredients question (no user message yet).
        chatbot = gr.Chatbot(
            value=[(None, bot.ai_prompt_questions["ingredients"].prompt.template)]
        )

        msg = gr.Textbox()
        gr.Markdown("**🔃Refresh the page to start from scratch🔃**")

        # bot.respond returns ("", history), which clears the textbox and
        # updates the chat. A single `bot` object backs every callback registered here.
        msg.submit(
            fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        )
    return demo