# chef-transformer / server.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datetime import datetime as dt
import streamlit as st
from streamlit_tags import st_tags
import beam_search
import top_sampling
from pprint import pprint
import json
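# Load the app configuration (model checkpoint, ingredient presets, ...) from config.json.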
with open("config.json") as f:
    cfg = json.loads(f.read())
st.set_page_config(layout="wide")
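# Cache the tokenizer/model/pipeline across Streamlit reruns so the checkpoint is loaded only once;
# allow_output_mutation=True skips hashing the returned (large, mutable) objects on each rerun.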
@st.cache(allow_output_mutation=True)
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(cfg["model_name_or_path"])
    model = AutoModelForSeq2SeqLM.from_pretrained(cfg["model_name_or_path"])
    generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
    return generator, tokenizer
def sampling_changed(obj):
    print(obj)
with st.spinner('Loading model...'):
    generator, tokenizer = load_model()
# st.image("images/chef-transformer.png", width=400)
st.header("Chef Transformer πŸ‘©β€πŸ³ / πŸ‘¨β€πŸ³")
st.markdown(
    "This demo uses [T5 trained on recipe-nlg](https://huggingface.co/flax-community/t5-recipe-generation) "
    "to generate a recipe from a given set of ingredients."
)
img = st.sidebar.image("images/chef-transformer-transparent.png", width=310)
add_text_sidebar = st.sidebar.title("Popular recipes:")
add_text_sidebar = st.sidebar.text("Recipe preset (example #1)")
add_text_sidebar = st.sidebar.text("Recipe preset (example #2)")
add_text_sidebar = st.sidebar.title("Mode:")
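# Decoding strategy for generation. The actual generate() kwargs live in beam_search.py and
# top_sampling.py as `generate_kwargs` dicts; a hypothetical sketch of their shape (values are
# illustrative only, not the repo's real settings):
#   beam_search.generate_kwargs  ~ {"max_length": 512, "num_beams": 4, "early_stopping": True}
#   top_sampling.generate_kwargs ~ {"max_length": 512, "do_sample": True, "top_k": 60, "top_p": 0.95}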
sampling_mode = st.sidebar.selectbox("Select a mode", index=0, options=["Top Sampling", "Beam Search"])
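# cfg["first_100"] supplies the preset ingredient options shown in the multiselect.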
original_keywords = st.multiselect(
    "Choose ingredients",
    cfg["first_100"],
    ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
)
# st.write("Add custom ingredients here:")
# custom_keywords = st_tags(
#     label="",
#     text='Press enter to add more',
#     value=['salt'],
#     suggestions=["z"],
#     maxtags=15,
#     key='1')
def custom_keywords_on_change():
    pass
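# Plain text input as an alternative to the streamlit_tags widget above (kept commented out);
# custom ingredients are entered as a single comma-separated string.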
custom_keywords = st.text_input(
    'Add custom ingredients here (separated by `,`): ',
    ", ".join(["salt", "pepper"]),
    key="custom_keywords",
    on_change=custom_keywords_on_change,
    max_chars=1000)
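# Split on commas, trim whitespace, and drop empty/duplicate entries.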
custom_keywords = list(set([x.strip() for x in custom_keywords.strip().split(',') if len(x.strip()) > 0]))
all_ingredients = []
all_ingredients.extend(original_keywords)
all_ingredients.extend(custom_keywords)
all_ingredients = ", ".join(all_ingredients)
st.markdown("**Generate recipe for:** " + all_ingredients)
submit = st.button('Get Recipe!')
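# Generate and render a recipe only after the user clicks the button.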
if submit:
    with st.spinner('Generating recipe...'):
        if sampling_mode == "Beam Search":
            generated = generator(all_ingredients, return_tensors=True, return_text=False,
                                  **beam_search.generate_kwargs)
            outputs = beam_search.post_generator(generated, tokenizer)
        elif sampling_mode == "Top Sampling":
            generated = generator(all_ingredients, return_tensors=True, return_text=False,
                                  **top_sampling.generate_kwargs)
            outputs = top_sampling.post_generator(generated, tokenizer)
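    # post_generator (from beam_search.py / top_sampling.py) decodes the generated token ids into
    # a list of recipe dicts with "title", "ingredients", and "directions" keys (inferred from the
    # usage below).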
    output = outputs[0]
    output['title'] = " ".join([w.capitalize() for w in output['title'].split()])
    markdown_output = ""
    markdown_output += f"## {output['title']}\n"
    markdown_output += f"#### Ingredients:\n"
    for o in output["ingredients"]:
        markdown_output += f"- {o}\n"
    markdown_output += f"#### Directions:\n"
    for o in output["directions"]:
        markdown_output += f"- {o}\n"
    st.markdown(markdown_output)