"""Streamlit demo app: generate a recipe from a set of ingredients with a T5 model.

Settings come from ``config.json``: ``model_name_or_path`` names the checkpoint
to load and ``first_100`` supplies the ingredient choices for the multiselect.
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datetime import datetime as dt
import streamlit as st
from streamlit_tags import st_tags
import beam_search
import top_sampling
from pprint import pprint
import json

with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

st.set_page_config(layout="wide")

# Labels shown in the sidebar "Mode" selectbox. The generation dispatch below
# keys on these same constants so the UI labels and the branch logic cannot
# drift apart (previously the code compared against "Top-k Sampling" while the
# selectbox offered "Top Sampling", so the default mode crashed).
TOP_SAMPLING = "Top Sampling"
BEAM_SEARCH = "Beam Search"


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and seq2seq model named in ``cfg`` and wrap them in a pipeline.

    Decorated with ``st.cache`` so that Streamlit's script reruns reuse the
    loaded objects instead of reloading the model each time.

    Returns:
        A ``(generator, tokenizer)`` tuple, where ``generator`` is a
        ``text2text-generation`` pipeline built from the loaded model.
    """
    tokenizer = AutoTokenizer.from_pretrained(cfg["model_name_or_path"])
    model = AutoModelForSeq2SeqLM.from_pretrained(cfg["model_name_or_path"])
    generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
    return generator, tokenizer


def sampling_changed(obj):
    """Debug helper: print *obj* to stdout. Not wired to any widget in this file."""
    print(obj)


def custom_keywords_on_change():
    """No-op ``on_change`` callback for the custom-ingredients text input."""
    pass


def _parse_custom_keywords(raw):
    """Split a comma-separated string into unique, stripped, non-empty keywords.

    Order of first appearance is preserved; a plain ``set`` would iterate in
    an arbitrary, run-dependent order and change the prompt fed to the model.
    """
    stripped = (part.strip() for part in raw.strip().split(","))
    return list(dict.fromkeys(part for part in stripped if part))


def _render_recipe_markdown(recipe):
    """Build the markdown shown for one generated recipe.

    Args:
        recipe: mapping with a ``"title"`` string and ``"ingredients"`` /
            ``"directions"`` iterables of strings. This is the shape read from
            ``post_generator`` output here; confirm against the
            ``beam_search`` / ``top_sampling`` modules.

    Returns:
        Markdown text: a title-cased ``##`` heading followed by bulleted
        "Ingredients" and "Directions" sections.
    """
    title = " ".join(word.capitalize() for word in recipe["title"].split())
    lines = [f"## {title}", "#### Ingredients:"]
    lines.extend(f"- {item}" for item in recipe["ingredients"])
    lines.append("#### Directions:")
    lines.extend(f"- {step}" for step in recipe["directions"])
    return "\n".join(lines) + "\n"


with st.spinner('Loading model...'):
    generator, tokenizer = load_model()

st.header("Chef Transformer 👩‍🍳 / 👨‍🍳")
st.markdown(
    "This demo uses [T5 trained on recipe-nlg](https://huggingface.co/flax-community/t5-recipe-generation) "
    "to generate recipe from a given set of ingredients"
)

st.sidebar.image("images/chef-transformer-transparent.png", width=310)
st.sidebar.title("Popular recipes:")
st.sidebar.text("Recipe preset(example#1)")
st.sidebar.text("Recipe preset(example#2)")
st.sidebar.title("Mode:")
sampling_mode = st.sidebar.selectbox("select a Mode", index=0, options=[TOP_SAMPLING, BEAM_SEARCH])

original_keywords = st.multiselect(
    "Choose ingredients",
    cfg["first_100"],
    ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"],
)

custom_keywords = st.text_input(
    'Add custom ingredients here (separated by `,`): ',
    ", ".join(["salt", "pepper"]),
    key="custom_keywords",
    on_change=custom_keywords_on_change,
    max_chars=1000,
)
custom_keywords = _parse_custom_keywords(custom_keywords)

# Merge both ingredient sources in order, dropping any entered in both places.
all_ingredients = ", ".join(dict.fromkeys([*original_keywords, *custom_keywords]))
st.markdown("**Generate recipe for:** " + all_ingredients)

# Each mode maps to the module supplying its ``generate_kwargs`` and
# ``post_generator``; both modules expose the same two names.
_DECODERS = {TOP_SAMPLING: top_sampling, BEAM_SEARCH: beam_search}

submit = st.button('Get Recipe!')
if submit:
    if not all_ingredients:
        # Avoid sending an empty prompt to the model.
        st.warning("Please choose or enter at least one ingredient.")
    else:
        decoder = _DECODERS[sampling_mode]
        with st.spinner('Generating recipe...'):
            generated = generator(
                all_ingredients,
                return_tensors=True,
                return_text=False,
                **decoder.generate_kwargs,
            )
            outputs = decoder.post_generator(generated, tokenizer)
        st.markdown(_render_recipe_markdown(outputs[0]))