File size: 4,612 Bytes
adaea7c
 
 
6e5b58a
adaea7c
 
986fa13
adaea7c
ae2bbf3
986fa13
6e5b58a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_app.ipynb.

# %% auto 0
__all__ = ['handle_requires_action', 'run_convo_stream', 'predict', 'create_demo']

# %% ../nbs/01_app.ipynb 3
import copy
import os
import gradio as gr
import constants
from lv_recipe_chatbot.vegan_recipe_assistant import (
    SYSTEM_PROMPT,
    vegan_recipe_edamam_search,
    VEGAN_RECIPE_SEARCH_TOOL_SCHEMA,
)
from openai import OpenAI, AssistantEventHandler
from typing_extensions import override
import json
from functools import partial

# %% ../nbs/01_app.ipynb 9
def handle_requires_action(data):
    """Execute the tool calls requested by a paused Assistants run.

    Args:
        data: The run object carried by a ``thread.run.requires_action`` stream
            event; its ``required_action.submit_tool_outputs.tool_calls`` lists
            the calls to execute.

    Returns:
        A list of ``{"tool_call_id": ..., "output": ...}`` dicts with exactly one
        entry per requested tool call, in request order, ready to pass as
        ``tool_outputs`` to ``client.beta.threads.runs.submit_tool_outputs``.
    """
    tool_outputs = []
    for tool_call in data.required_action.submit_tool_outputs.tool_calls:
        name = tool_call.function.name
        if name == "vegan_recipe_edamam_search":
            try:
                fn_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError as err:
                # Report bad arguments back to the model instead of crashing the stream.
                output = f"Invalid arguments for {name}: {err}"
            else:
                # Separate name so the ``data`` argument is not overwritten mid-loop.
                output = vegan_recipe_edamam_search(query=fn_args.get("query"))
        else:
            # The Assistants API expects an output for every requested tool call;
            # omitting one leaves the run unable to continue.
            output = f"Unknown tool: {name}"
        tool_outputs.append({"tool_call_id": tool_call.id, "output": output})
    return tool_outputs

# %% ../nbs/01_app.ipynb 11
def _iter_text_deltas(delta):
    """Yield the non-empty text fragments carried by a message-delta payload.

    Content blocks without text (e.g. image deltas) and text blocks whose
    ``value`` is missing or empty are skipped, so callers can always concatenate
    the yielded strings.
    """
    for block in delta.content or []:
        text = getattr(block, "text", None)
        value = getattr(text, "value", None)
        if value:
            yield value


def _stream_text(stream, thread, client):
    """Yield assistant text from a run event stream, servicing tool calls.

    On ``thread.run.requires_action`` the requested tools are executed via
    ``handle_requires_action`` and their outputs submitted with ``stream=True``.
    The resulting stream is consumed recursively, so follow-up tool calls made
    after the first round are serviced too. After that, iteration of the
    original stream resumes.
    """
    for event in stream:
        if event.event == "thread.message.delta":
            yield from _iter_text_deltas(event.data.delta)
        elif event.event == "thread.run.requires_action":
            tool_stream = client.beta.threads.runs.submit_tool_outputs(
                run_id=event.data.id,
                thread_id=thread.id,
                tool_outputs=handle_requires_action(event.data),
                stream=True,
            )
            yield from _stream_text(tool_stream, thread, client)


def run_convo_stream(thread, content: "str | list[dict]", client: OpenAI, assistant):
    """Post a user message to ``thread`` and stream the assistant's reply.

    Args:
        thread: OpenAI thread object; only its ``id`` is used.
        content: The user message. Either plain text, or a list of content parts
            (``predict`` passes text plus an ``image_file`` part for uploads).
        client: OpenAI client used to create the message and runs.
        assistant: Assistant object; only its ``id`` is used.

    Yields:
        Incremental text fragments of the assistant's reply, including text
        produced after any tool calls are serviced.
    """
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=content,
    )
    run_stream = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
        stream=True,
    )
    yield from _stream_text(run_stream, thread, client)

# %% ../nbs/01_app.ipynb 13
# File extensions (compared case-insensitively) accepted for image uploads.
_SUPPORTED_IMAGE_EXTENSIONS = ("jpg", "png", "jpeg", "webp")

# Instruction sent alongside an uploaded image.
_IMAGE_PROMPT = (
    "What vegan ingredients do you see in this image? Also list out a few "
    "combinations of the ingredients that go well together. Lastly, suggest a "
    "recipe based on one of those combos using the vegan recipe search tool."
)


def predict(message, history, client: OpenAI, assistant, thread):
    """Gradio multimodal chat callback that streams the assistant's reply.

    If files were uploaded, only the last one is used, and any text is ignored.
    The file must have an image extension. It is uploaded with
    ``purpose="vision"`` and sent with a fixed ingredient/recipe prompt.
    Otherwise the message text is sent as-is.

    Args:
        message: Gradio multimodal message dict with ``"text"`` (str) and
            ``"files"`` (list of uploaded file paths).
        history: Chat history supplied by Gradio. It is unused here, because the
            conversation is kept in the OpenAI ``thread``.
        client: OpenAI client used for file upload and runs.
        assistant: Assistant object passed through to ``run_convo_stream``.
        thread: OpenAI thread the message is posted to.

    Yields:
        The full reply accumulated so far, so the UI can render it incrementally.
        For a non-image upload, a single rejection message is yielded instead.
    """
    files = message["files"]
    txt = message["text"]

    if files:
        image_path = files[-1]
        extension = image_path.rsplit(".", 1)[-1].lower()
        if extension not in _SUPPORTED_IMAGE_EXTENSIONS:
            # This function is a generator: a ``return value`` would be dropped
            # by the consumer, so yield the message for the user to see.
            yield "Sorry only accept image files"
            return
        # Close the local file handle once the upload completes.
        with open(image_path, "rb") as image_file:
            uploaded = client.files.create(file=image_file, purpose="vision")
        content = [
            {"type": "text", "text": _IMAGE_PROMPT},
            {"type": "image_file", "image_file": {"file_id": uploaded.id}},
        ]
    elif txt:
        content = txt
    else:
        return

    reply = ""
    for reply_txt in run_convo_stream(thread, content, client, assistant):
        reply += reply_txt
        yield reply

# %% ../nbs/01_app.ipynb 14
def create_demo(client: OpenAI, assistant):
    """Build the Gradio multimodal chat UI backed by an OpenAI assistant.

    Args:
        client: OpenAI client used to create the thread and, via ``predict``,
            to upload files and run the assistant.
        assistant: Assistant object passed through to ``predict``.

    Returns:
        The ``gr.ChatInterface`` app, with an informational Markdown footer.
    """
    # NOTE(review): one OpenAI thread is created here, when the demo is built,
    # and bound into every ``predict`` call via ``partial``. All browser sessions
    # (and page refreshes) therefore share a single server-side conversation,
    # even though the footer says refreshing starts from scratch.
    # TODO: create a thread per session.
    thread = client.beta.threads.create()
    pred = partial(predict, client=client, assistant=assistant, thread=thread)

    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#customizing-your-chatbot
    with gr.ChatInterface(
        fn=pred,
        multimodal=True,
        chatbot=gr.Chatbot(
            placeholder="Hello!\nI am an animal advocate AI that is capable of recommending vegan recipes.\nUpload an image or write a message below to get started!"
        ),
    ) as demo:
        gr.Markdown(
            """🔃 **Refresh the page to start from scratch**  
        
        Recipe search tool powered by the [Edamam API](https://www.edamam.com/)  
        
        ![Edamam Logo](https://www.edamam.com/assets/img/small-logo.png)"""
        )
    return demo