# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_app.ipynb.

# %% auto 0
__all__ = ['ConversationBot', 'create_demo']

# %% ../nbs/01_app.ipynb 3
import copy
import os

import gradio as gr
from langchain import LLMChain, OpenAI, PromptTemplate
from langchain.agents import (
    AgentExecutor,
    AgentType,
    OpenAIFunctionsAgent,
    Tool,
    initialize_agent,
    load_tools,
)
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from PIL import Image

import constants
from .engineer_prompt import INIT_PROMPT
from .vegan_recipe_tools import vegan_recipe_edamam_search
from lv_recipe_chatbot.ingredient_vision import (
    SAMPLE_IMG_DIR,
    BlipImageCaptioning,
    VeganIngredientFinder,
    format_image,
)
# %% ../nbs/01_app.ipynb 16
class ConversationBot:
    """Stateful chat assistant that gathers a user's recipe preferences and
    then hands the conversation to a LangChain agent with a vegan-recipe
    search tool.

    Flow:
      1. Scripted phase — asks the fixed questions from ``INIT_PROMPT``
         (ingredients -> allergies -> free-form recipe parameters).
      2. Agent phase — after the third user reply, an OpenAI-functions agent
         armed with the Edamam search tool takes over the conversation.
    """

    # Memory key shared by the agent prompt placeholder and the buffer memory.
    memory_key: str = "chat_history"

    def __init__(
        self,
        vegan_ingred_finder: VeganIngredientFinder,
        img_cap: BlipImageCaptioning,
        verbose: bool = True,
    ):
        """
        Args:
            vegan_ingred_finder: extracts vegan ingredients from an image file.
            img_cap: BLIP captioner used to describe uploaded images.
            verbose: forwarded to the LLM for debug logging.
        """
        self.llm = ChatOpenAI(temperature=0.1, verbose=verbose)
        # Deep-copy so run_img can rewrite this bot's prompt without mutating
        # the shared module-level INIT_PROMPT.
        self.init_prompt = copy.deepcopy(INIT_PROMPT)
        self.img_cap = img_cap
        self.vegan_ingred_finder = vegan_ingred_finder
        self.verbose = verbose
        init_prompt_msgs = self.init_prompt.messages
        # The scripted AI questions live at fixed indices in the init prompt.
        self.ai_prompt_questions = {
            "ingredients": init_prompt_msgs[1],
            "allergies": init_prompt_msgs[3],
            "recipe_open_params": init_prompt_msgs[5],
        }

    def respond(self, user_msg, chat_history):
        """Gradio submit callback.

        Appends ``(user_msg, bot_reply)`` to ``chat_history`` and returns
        ``("", chat_history)``; the empty string clears the input textbox.
        """
        response = self._get_bot_response(user_msg, chat_history)
        chat_history.append((user_msg, response))
        return "", chat_history

    def init_agent_executor(self, chat_msgs):
        """Create the OpenAI-functions agent, seeded with the scripted chat.

        Args:
            chat_msgs: messages exchanged during the scripted phase; they
                become both the agent's prompt prefix and its memory.
        """
        tools = [vegan_recipe_edamam_search]
        prompt = OpenAIFunctionsAgent.create_prompt(
            system_message=self.init_prompt.messages[0],
            extra_prompt_messages=chat_msgs
            + [MessagesPlaceholder(variable_name=self.memory_key)],
        )
        self.memory = ConversationBufferMemory(
            chat_memory=ChatMessageHistory(messages=chat_msgs),
            return_messages=True,
            memory_key=self.memory_key,
        )
        self.agent_executor = AgentExecutor(
            agent=OpenAIFunctionsAgent(llm=self.llm, tools=tools, prompt=prompt),
            tools=tools,
            memory=self.memory,
            verbose=True,
        )

    def reset(self):
        """Clear agent memory (if any) and restore the pristine init prompt."""
        # self.memory only exists once init_agent_executor has run; guard so
        # resetting a bot that never reached the agent phase does not raise
        # AttributeError.
        if hasattr(self, "memory"):
            self.memory.clear()
        self.init_prompt = copy.deepcopy(INIT_PROMPT)

    def run_img(self, image: str):
        """Caption an uploaded image, extract its vegan ingredients, and
        splice both into this bot's init prompt.

        Args:
            image: filepath of the uploaded image (Gradio ``type="filepath"``).

        Returns:
            The message describing the image and its extracted ingredients.
        """
        desc = self.img_cap.inference(format_image(image))
        answer = self.vegan_ingred_finder.list_ingredients(image)
        msg = f"""I uploaded an image that may contain vegan ingredients.
The description of the image is: `{desc}`.
The extracted ingredients are:
```
{answer}
```"""
        # Rebuild from the pristine module-level template so repeated uploads
        # replace, rather than accumulate, earlier image descriptions.
        base_prompt = INIT_PROMPT.messages[2].prompt.template
        new_prompt = f"{msg}I may type some more ingredients below.\n{base_prompt}"
        self.init_prompt.messages[2].prompt.template = new_prompt
        return msg

    def _get_bot_response(self, user_msg: str, chat_history) -> str:
        """Route a user message to the next scripted question or the agent.

        The UI pre-seeds one ``(None, ingredients-question)`` pair, so the
        first user reply arrives with history length 1:
          len 1  -> ask about allergies
          len 2  -> ask for free-form recipe parameters
          len 3  -> format the full init prompt, run the LLM once, build the
                    agent, and let it perform the first recipe search
          len 4+ -> delegate entirely to the agent
        """
        if len(chat_history) < 2:
            return self.ai_prompt_questions["allergies"].prompt.template
        if len(chat_history) < 3:
            return self.ai_prompt_questions["recipe_open_params"].prompt.template
        if len(chat_history) < 4:
            # Skip the seeded first pair; collect the user's scripted answers.
            user_msgs = [msg_pair[0] for msg_pair in chat_history[1:]]
            f_init_prompt = self.init_prompt.format_prompt(
                ingredients=user_msgs[0],
                allergies=user_msgs[1],
                recipe_freeform_input=user_msg,
            )
            chat_msgs = f_init_prompt.to_messages()
            results = self.llm.generate([chat_msgs])
            chat_msgs.append(results.generations[0][0].message)
            # Prepare the agent to take over from this point.
            self.init_agent_executor(chat_msgs)
            return self.agent_executor.run("Search for a vegan recipe with that query")
        return self.agent_executor.run(input=user_msg)
# %% ../nbs/01_app.ipynb 20
def create_demo(bot: ConversationBot):
    """Assemble the Gradio UI for the recipe chatbot.

    Layout: an image uploader + button wired to ``bot.run_img``, clickable
    sample images, and a chatbot pre-seeded with the scripted opening
    question, wired to ``bot.respond``.

    Args:
        bot: the ConversationBot backing the UI callbacks.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    # Sort so the examples are deterministic across runs (os.listdir order is
    # arbitrary), then take the 2nd-4th entries as sample images.
    all_imgs = sorted(f"{SAMPLE_IMG_DIR}/{img}" for img in os.listdir(SAMPLE_IMG_DIR))
    sample_images = all_imgs[1:4]
    with gr.Blocks() as demo:
        gr_img = gr.Image(type="filepath")
        btn = gr.Button(value="Submit image")
        ingredients_msg = gr.Text(label="Ingredients from image")
        btn.click(bot.run_img, inputs=[gr_img], outputs=[ingredients_msg])
        gr.Examples(
            examples=sample_images,
            inputs=gr_img,
        )
        # Seed the chat with the scripted opening question so the user's first
        # reply arrives with history length 1 in bot.respond.
        chatbot = gr.Chatbot(
            value=[(None, bot.ai_prompt_questions["ingredients"].prompt.template)]
        )
        msg = gr.Textbox()
        gr.Markdown(
            """
**🔃Refresh the page to start from scratch🔃**
Recipe search tool powered by the [Edamam API](https://www.edamam.com/)
![Edamam Logo](https://www.edamam.com/assets/img/small-logo.png)
"""
        )
        msg.submit(
            fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        )
    return demo