File size: 3,576 Bytes
39f82cb
9d7c7b4
39f82cb
 
 
 
 
 
 
 
 
 
 
 
9d7c7b4
39f82cb
 
9d7c7b4
 
39f82cb
 
 
9d7c7b4
39f82cb
 
9d7c7b4
39f82cb
 
 
 
9d7c7b4
 
 
 
 
 
39f82cb
 
 
 
 
9d7c7b4
39f82cb
9d7c7b4
 
39f82cb
 
 
 
91c6692
 
 
 
 
 
 
 
9d7c7b4
3d35750
 
9d7c7b4
 
91c6692
 
 
 
 
 
 
9d7c7b4
39f82cb
 
 
 
9d7c7b4
39f82cb
 
 
 
 
3d35750
 
 
 
 
39f82cb
3d35750
 
 
 
 
 
39f82cb
 
9d7c7b4
39f82cb
9d7c7b4
39f82cb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datetime import datetime as dt
import streamlit as st
from streamlit_tags import st_tags
import beam_search
import top_sampling
from pprint import pprint
import json

# Runtime settings for the demo. This script reads "model_name_or_path"
# (checkpoint passed to from_pretrained) and "first_100" (the ingredient
# choices offered in the multiselect).
# Explicit UTF-8 so non-ASCII ingredient names decode the same on every
# platform instead of depending on the locale's default encoding.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Streamlit requires set_page_config to be the first Streamlit command.
st.set_page_config(layout="wide")


@st.cache(allow_output_mutation=True)
def load_model():
    """Build the text2text-generation pipeline for the configured checkpoint.

    The checkpoint named by ``cfg["model_name_or_path"]`` supplies both the
    tokenizer and the seq2seq model. The result is memoised with ``st.cache``
    (output mutation allowed) so the model is not reloaded on every rerun.

    Returns:
        tuple: ``(generator, tokenizer)`` -- the pipeline and the tokenizer
        it wraps (the tokenizer is needed later to decode raw tensors).
    """
    checkpoint = cfg["model_name_or_path"]
    tok = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    text2text = pipeline("text2text-generation", model=seq2seq, tokenizer=tok)
    return (text2text, tok)


def sampling_changed(obj):
    """Debug hook: echo *obj* (e.g. a newly selected sampling mode) to stdout.

    Nothing in this script currently registers it as a widget callback.
    """
    shown = obj
    print(shown)


with st.spinner('Loading model...'):
    generator, tokenizer = load_model()

# Header emoji are written as escapes (woman cook / man cook ZWJ sequences:
# U+1F469 U+200D U+1F373 and U+1F468 U+200D U+1F373) so the literal cannot be
# mangled by a mis-detected source encoding again.
st.header("Chef Transformer \U0001F469\u200D\U0001F373 / \U0001F468\u200D\U0001F373")
st.markdown(
    "This demo uses [T5 trained on recipe-nlg](https://huggingface.co/flax-community/t5-recipe-generation) "
    "to generate recipe from a given set of ingredients"
)

# Sidebar: logo, placeholder recipe presets, and the decoding-strategy picker.
st.sidebar.image("images/chef-transformer-transparent.png", width=310)
st.sidebar.title("Popular recipes:")
st.sidebar.text("Recipe preset(example#1)")
st.sidebar.text("Recipe preset(example#2)")

st.sidebar.title("Mode:")
sampling_mode = st.sidebar.selectbox("select a Mode", index=0, options=["Top Sampling", "Beam Search"])

original_keywords = st.multiselect(
    "Choose ingredients",
    cfg["first_100"],
    ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
)

st.write("Add custom ingredients here:")
custom_keywords = st_tags(
    label="",
    text='Press enter to add more',
    value=['salt'],
    suggestions=["z"],
    maxtags=15,
    key='1')

# Merge both ingredient sources into a single prompt string. Entries are
# trimmed, blank tags are dropped, and duplicates (e.g. the same ingredient
# picked in the multiselect and typed as a tag) are removed while keeping
# first-seen order -- dict.fromkeys preserves insertion order.
ingredient_list = list(dict.fromkeys(
    keyword.strip()
    for keyword in [*(original_keywords or []), *(custom_keywords or [])]
    if keyword.strip()
))
all_ingredients = ", ".join(ingredient_list)
st.markdown("**Generate recipe for:** " + all_ingredients)

def _recipe_to_markdown(recipe):
    """Render one generated recipe as Markdown.

    ``recipe`` is expected to carry ``"title"`` (str) plus ``"ingredients"``
    and ``"directions"`` (iterables of str), as returned by the strategy
    modules' ``post_generator``. Each title word is capitalised; the input
    dict is not modified.
    """
    title = " ".join(word.capitalize() for word in recipe["title"].split())
    lines = [f"## {title}", "#### Ingredients:"]
    lines.extend(f"- {item}" for item in recipe["ingredients"])
    lines.append("#### Directions:")
    lines.extend(f"- {step}" for step in recipe["directions"])
    return "\n".join(lines) + "\n"


# Sidebar mode label -> strategy module exposing ``generate_kwargs`` and
# ``post_generator(generated, tokenizer)``.
_STRATEGIES = {
    "Beam Search": beam_search,
    "Top Sampling": top_sampling,
}

submit = st.button('Get Recipe!')
if submit:
    if not all_ingredients:
        # Nothing to condition on: don't run the model on an empty prompt.
        st.warning("Please choose or add at least one ingredient.")
    else:
        strategy = _STRATEGIES[sampling_mode]
        with st.spinner('Generating recipe...'):
            # Ask for raw tensors so the strategy's post_generator can decode
            # and split them with the tokenizer itself.
            generated = generator(
                all_ingredients,
                return_tensors=True,
                return_text=False,
                **strategy.generate_kwargs)
            outputs = strategy.post_generator(generated, tokenizer)
        if outputs:
            st.markdown(_recipe_to_markdown(outputs[0]))
        else:
            st.error("No recipe could be generated for these ingredients.")