File size: 3,515 Bytes
39f82cb
9d7c7b4
39f82cb
 
 
 
 
 
 
 
 
 
 
 
9d7c7b4
39f82cb
 
9d7c7b4
 
39f82cb
 
 
9d7c7b4
39f82cb
 
9d7c7b4
39f82cb
 
 
 
9d7c7b4
 
 
 
 
 
39f82cb
 
 
 
 
9d7c7b4
39f82cb
9d7c7b4
 
39f82cb
 
 
 
9d7c7b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18b0e4b
9d7c7b4
 
 
39f82cb
 
 
 
9d7c7b4
39f82cb
 
 
 
 
9d7c7b4
 
39f82cb
 
9d7c7b4
 
39f82cb
 
9d7c7b4
39f82cb
9d7c7b4
39f82cb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datetime import datetime as dt
import streamlit as st
from streamlit_tags import st_tags
import beam_search
import top_sampling
from pprint import pprint
import json

# Runtime configuration. This script reads "model_name_or_path" (the model
# checkpoint) and "first_100" (the ingredient choices offered in the
# multiselect) from it.
# Decode explicitly as UTF-8: without `encoding`, open() uses the locale's
# preferred encoding, so non-ASCII ingredient names could load differently
# (or fail) depending on the host platform.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Streamlit requires set_page_config to be the first Streamlit command run.
st.set_page_config(layout="wide")


@st.cache(allow_output_mutation=True)
def load_model():
    """Build the text2text generation pipeline for the configured checkpoint.

    The tokenizer and the seq2seq model are both loaded from
    ``cfg["model_name_or_path"]`` (tokenizer first, then model) and combined
    into a ``"text2text-generation"`` pipeline.

    ``st.cache`` memoizes the result across reruns. ``allow_output_mutation``
    stops Streamlit from hashing the returned objects to detect mutation.

    Returns:
        A ``(generator, tokenizer)`` tuple: the pipeline and the tokenizer it
        wraps.
    """
    checkpoint = cfg["model_name_or_path"]
    tok = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    text2text = pipeline("text2text-generation", model=seq2seq, tokenizer=tok)
    return text2text, tok


def sampling_changed(obj):
    """Print *obj* to stdout.

    Presumably meant as a debugging on_change hook for the sampling-mode
    widget. It is not wired to any widget in this script.
    """
    message = obj
    print(message)


with st.spinner('Loading model...'):
    generator, tokenizer = load_model()
# st.image("images/chef-transformer.png", width=400)
# Emoji are written as escapes (woman cook / man cook, each a ZWJ sequence)
# so the literal cannot be mangled by an encoding mix-up again; it had been
# saved as mis-decoded UTF-8 bytes that rendered as garbage.
st.header("Chef Transformer \U0001F469\u200d\U0001F373 / \U0001F468\u200d\U0001F373")
st.markdown(
    "This demo uses [T5 trained on recipe-nlg](https://huggingface.co/flax-community/t5-recipe-generation) "
    "to generate recipe from a given set of ingredients"
)
st.sidebar.image("images/chef-transformer-transparent.png", width=310)
st.sidebar.title("Popular recipes:")
st.sidebar.text("Recipe preset(example#1)")
st.sidebar.text("Recipe preset(example#2)")

st.sidebar.title("Mode:")
# The generation code below compares these option labels verbatim against
# its mode names; keep the two in sync when renaming an option.
sampling_mode = st.sidebar.selectbox("select a Mode", index=0, options=["Top Sampling", "Beam Search"])

# NOTE(review): Streamlit rejects multiselect defaults that are not among the
# options, so these four are assumed to be present in cfg["first_100"] —
# confirm against config.json.
original_keywords = st.multiselect(
    "Choose ingredients",
    cfg["first_100"],
    ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
)


# st.write("Add custom ingredients here:")
# custom_keywords = st_tags(
#     label="",
#     text='Press enter to add more',
#     value=['salt'],
#     suggestions=["z"],
#     maxtags=15,
#     key='1')

def custom_keywords_on_change():
    """Do nothing; placeholder callback for the custom-ingredients text input.

    The ``on_change=`` hookup on that widget is currently commented out.
    """


custom_keywords = st.text_input(
    'Add custom ingredients here (separated by `,`): ',
    ", ".join(["salt", "pepper"]),
    key="custom_keywords",
    # on_change=custom_keywords_on_change,
    max_chars=1000)
# Split the free-text field on commas, trim each entry and drop empty ones.
# dict.fromkeys de-duplicates while keeping the order the user typed. The
# previous list(set(...)) discarded that order, so ingredients reached the
# model in arbitrary order.
custom_keywords = list(dict.fromkeys(
    x.strip() for x in custom_keywords.split(",") if x.strip()
))

# Prompt text: multiselect choices first, then custom entries. An ingredient
# given in both places is included only once.
all_ingredients = ", ".join(dict.fromkeys([*original_keywords, *custom_keywords]))
st.markdown("**Generate recipe for:** " + all_ingredients)

# Generation module for each sidebar "Mode" option. Keys must match the
# selectbox option labels exactly. The old code compared against
# "Top-k Sampling", which is not an option, so the default "Top Sampling"
# mode left `outputs` unassigned and crashed with a NameError.
# Both modules are used through the same two names: `generate_kwargs` (extra
# keyword arguments for the pipeline call) and
# `post_generator(generated, tokenizer)`. Its result is indexed like a list of
# mappings with "title", "ingredients" and "directions" entries.
SAMPLING_STRATEGIES = {
    "Beam Search": beam_search,
    "Top Sampling": top_sampling,
}

submit = st.button('Get Recipe!')
if submit and not all_ingredients:
    # Nothing to condition on; don't send an empty prompt to the model.
    st.warning("Please choose or add at least one ingredient.")
elif submit:
    strategy = SAMPLING_STRATEGIES[sampling_mode]
    with st.spinner('Generating recipe...'):
        generated = generator(all_ingredients, return_tensors=True, return_text=False,
                              **strategy.generate_kwargs)
        outputs = strategy.post_generator(generated, tokenizer)
    output = outputs[0]
    # Capitalize each word (str.capitalize also lowercases the rest of it).
    output['title'] = " ".join(w.capitalize() for w in output['title'].split())
    markdown_lines = [f"## {output['title']}", "#### Ingredients:"]
    markdown_lines += [f"- {o}" for o in output["ingredients"]]
    markdown_lines.append("#### Directions:")
    markdown_lines += [f"- {o}" for o in output["directions"]]
    st.markdown("\n".join(markdown_lines) + "\n")