File size: 3,510 Bytes
adaea7c
 
 
 
 
 
 
 
ae2bbf3
adaea7c
ae2bbf3
adaea7c
ae2bbf3
 
adaea7c
ae2bbf3
 
 
 
adaea7c
 
 
ae2bbf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adaea7c
 
 
 
 
 
 
ae2bbf3
adaea7c
 
ae2bbf3
adaea7c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_app.ipynb.

# %% auto 0
__all__ = ['ConversationBot', 'launch_demo']

# %% ../nbs/01_app.ipynb 3
import os

import gradio as gr
from fastcore.utils import in_jupyter
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)

from .engineer_prompt import init_prompt

# %% ../nbs/01_app.ipynb 4
class ConversationBot:
    """Recipe-assistant chatbot: three scripted questions, then open chat.

    The first bot turns are fixed questions taken from ``init_prompt``
    (ingredients, allergies, free-form recipe parameters). Once all three have
    been answered, the answers are substituted into ``init_prompt``, the chat
    model is queried once, and every later turn goes through a
    ``ConversationChain`` backed by ``ConversationBufferMemory``.

    The conversation stage is inferred from the length of the Gradio chat
    history, whose first entry is the ``(None, <ingredients question>)`` pair.

    NOTE(review): all conversation state (memory, chain) lives on the
    instance, so one instance can only track one conversation at a time.
    """

    def __init__(
        self,
    ):
        self.chat = ChatOpenAI(temperature=1, verbose=True)
        self.memory = ConversationBufferMemory(return_messages=True)
        self.init_prompt_msgs = init_prompt.messages
        # The scripted AI questions sit at odd positions of init_prompt's
        # message list; presumably index 0 is a system message and the even
        # indices are the human-answer templates — TODO confirm against
        # engineer_prompt.
        self.ai_prompt_questions = {
            "ingredients": self.init_prompt_msgs[1],
            "allergies": self.init_prompt_msgs[3],
            "recipe_open_params": self.init_prompt_msgs[5],
        }

    def respond(self, user_msg, chat_history):
        """Handle one user turn (Gradio ``submit`` callback).

        Args:
            user_msg: Text the user just submitted.
            chat_history: Chatbot value, a list of ``(user, bot)`` pairs. It may
                be empty or ``None`` (e.g. after the chat was cleared). In that
                case the opening ingredients question is restored first, so the
                length-based stage detection stays aligned.

        Returns:
            ``("", chat_history)``: an empty string to clear the textbox, and
            the history with the new ``(user_msg, response)`` pair appended.
        """
        if chat_history is None:
            chat_history = []
        if not chat_history:
            # Without the opening (None, question) pair every later stage would
            # be off by one, and the stage-3 answer extraction would pick the
            # wrong messages.
            chat_history.append(
                (None, self.ai_prompt_questions["ingredients"].prompt.template)
            )
        response = self._get_bot_response(user_msg, chat_history)
        chat_history.append((user_msg, response))
        return "", chat_history

    def init_conversation(self, formatted_chat_prompt):
        """Create the open-ended conversation chain using ``formatted_chat_prompt``."""
        self.conversation = ConversationChain(
            llm=self.chat,
            memory=self.memory,
            prompt=formatted_chat_prompt,
            verbose=True,
        )

    def reset(self):
        """Clear the conversation buffer memory."""
        self.memory.clear()

    def _get_bot_response(self, user_msg: str, chat_history: list) -> str:
        """Return the bot's reply to ``user_msg`` given the history before it.

        The stage is chosen from ``len(chat_history)``:

        * 1: ``user_msg`` answers the ingredients question, so ask about allergies.
        * 2: ``user_msg`` answers the allergies question, so ask for recipe parameters.
        * 3: all scripted answers are in. Query the model once and set up
          the open conversation chain.
        * 4+: delegate to the open conversation chain.
        """
        n_entries = len(chat_history)
        if n_entries < 2:
            return self.ai_prompt_questions["allergies"].prompt.template
        if n_entries < 3:
            return self.ai_prompt_questions["recipe_open_params"].prompt.template
        if n_entries < 4:
            return self._start_open_conversation(user_msg, chat_history)
        return self.conversation.predict(input=user_msg)

    def _start_open_conversation(self, user_msg: str, chat_history: list) -> str:
        """Fill ``init_prompt`` with the scripted answers and return the model's reply.

        Also builds the open-ended prompt from four parts and hands it to
        ``init_conversation`` for all later turns:

        * the filled scripted messages;
        * the model's reply;
        * the memory ``history`` placeholder;
        * an ``{input}`` template.
        """
        # chat_history[0] is the opening (None, ingredients-question) pair.
        # The user halves of entries 1 and 2 are therefore the answers to the
        # ingredients and allergies questions.
        ingredients = chat_history[1][0]
        allergies = chat_history[2][0]
        filled_prompt = init_prompt.format_prompt(
            ingredients=ingredients,
            allergies=allergies,
            recipe_freeform_input=user_msg,
        )
        chat_msgs = filled_prompt.to_messages()
        results = self.chat.generate([chat_msgs])
        first_reply = results.generations[0][0].message
        open_prompt = ChatPromptTemplate.from_messages(
            [
                *chat_msgs,
                first_reply,
                MessagesPlaceholder(variable_name="history"),
                HumanMessagePromptTemplate.from_template("{input}"),
            ]
        )
        # Prepare the open conversation chain from this point on.
        self.init_conversation(open_prompt)
        return first_reply.content

# %% ../nbs/01_app.ipynb 5
def launch_demo():
    """Build the Gradio chat UI around a ``ConversationBot`` and launch it.

    The app is served behind basic auth. The credentials are read from the
    ``GRADIO_DEMO_USERNAME`` and ``GRADIO_DEMO_PASSWORD`` environment
    variables; a missing variable raises ``KeyError``.

    NOTE(review): a single ``ConversationBot`` (and so a single memory buffer)
    is shared by every browser session, so concurrent users interleave state.
    """
    bot = ConversationBot()
    # The chat opens with the bot's ingredients question. ConversationBot
    # infers the conversation stage from the history length, so this entry
    # must be present whenever the conversation (re)starts.
    opening_history = [(None, bot.ai_prompt_questions["ingredients"].prompt.template)]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(value=list(opening_history))

        msg = gr.Textbox()
        clear = gr.Button("Clear")

        msg.submit(
            fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        )
        # Restore the opening question rather than emptying the chat, so the
        # length-based stage detection restarts at the beginning. Then drop
        # the bot's conversation memory.
        clear.click(
            lambda: list(opening_history), None, chatbot, queue=False
        ).then(bot.reset)

    demo.launch(
        auth=(
            os.environ["GRADIO_DEMO_USERNAME"],
            os.environ["GRADIO_DEMO_PASSWORD"],
        )
    )